diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 5d01a65d4d..d322782824 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -111,6 +111,8 @@ public void parseResponse(T response, List modelTensors, boolea if (response instanceof String && isJson((String) response)) { Map data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY); modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build()); + } else if (response instanceof Map) { + modelTensors.add(ModelTensor.builder().name("response").dataAsMap((Map) response).build()); } else { Map map = new HashMap<>(); map.put("response", response); diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index 6d075ab205..820638553e 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -289,4 +290,14 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } + + @Override + public String toString() { + try { + return this.toXContent(JsonXContent.contentBuilder(), null).toString(); + } catch (IOException e) { + throw new IllegalArgumentException("Can't convert ModelTensor to string", e); + } + } + } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java index ca485ec05b..2e8692048a 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.List; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -102,4 +103,13 @@ public static ModelTensorOutput parse(XContentParser parser) throws IOException return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build(); } + + @Override + public String toString() { + try { + return this.toXContent(JsonXContent.contentBuilder(), null).toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 8177a6ed56..97ed4cd7ca 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -13,6 +13,7 @@ import java.util.List; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -171,4 +172,13 @@ public static ModelTensors parse(XContentParser parser) throws IOException { modelTensors.setStatusCode(statusCode); return modelTensors; } + + @Override + public String toString() { + try { + return this.toXContent(JsonXContent.contentBuilder(), null).toString(); + } catch (IOException e) { + throw new IllegalArgumentException("Can't convert ModelTensor to string", e); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 22165c487f..d81bf217fd 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -35,6 +35,9 @@ import org.json.JSONObject; import org.opensearch.OpenSearchParseException; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -77,7 +80,6 @@ public class StringUtils { public static final String SAFE_INPUT_DESCRIPTION = "can only contain letters, numbers, spaces, and basic punctuation (.,!?():@-_'/\")"; - public static final Gson gson = new Gson(); public static final Gson PLAIN_NUMBER_GSON = new GsonBuilder() .serializeNulls() .registerTypeAdapter(Float.class, new PlainFloatAdapter()) @@ -86,6 +88,15 @@ public class StringUtils { .registerTypeAdapter(double.class, new PlainDoubleAdapter()) .create(); + public static final Gson gson; + static { + gson = new GsonBuilder() + .disableHtmlEscaping() + .registerTypeAdapter(ModelTensor.class, new ToStringTypeAdapter<>(ModelTensor.class)) + .registerTypeAdapter(ModelTensorOutput.class, new ToStringTypeAdapter<>(ModelTensorOutput.class)) + .registerTypeAdapter(ModelTensors.class, new ToStringTypeAdapter<>(ModelTensors.class)) + .create(); + } public static final String TO_STRING_FUNCTION_NAME = ".toString()"; public static final ObjectMapper MAPPER = new ObjectMapper(); diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ToStringTypeAdapter.java b/common/src/main/java/org/opensearch/ml/common/utils/ToStringTypeAdapter.java new file mode 100644 index 0000000000..3c58c2cff9 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/utils/ToStringTypeAdapter.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import java.io.IOException; + +import com.google.gson.TypeAdapter; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; + +public class ToStringTypeAdapter extends TypeAdapter { + + private final Class clazz; + + public ToStringTypeAdapter(Class clazz) { + this.clazz = clazz; + } + + @Override + public void write(JsonWriter out, T value) throws IOException { + if (value == null) { + out.nullValue(); + return; + } + String json = value.toString(); + out.jsonValue(json); + } + + @Override + public T read(JsonReader in) throws IOException { + throw new UnsupportedOperationException("Deserialization not supported"); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java index becf53e3c8..773d0a30fc 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java @@ -215,4 +215,38 @@ public static String getToolName(MLToolSpec toolSpec) { return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType(); } + /** + * Converts various types of tool output into a standardized ModelTensor format. + * The conversion logic depends on the type of input: + *
    + *
  • For Map inputs: directly uses the map as data
  • + *
  • For List inputs: wraps the list in a map with "output" as the key
  • + *
  • For other types: converts to JSON string and attempts to parse as map, + * if parsing fails, wraps the original output in a map with "output" as the key
  • + *
+ * + * @param output The output object to be converted. Can be a Map, List, or any other object + * @param outputKey The key/name to be assigned to the resulting ModelTensor + * @return A ModelTensor containing the formatted output data + */ + public static ModelTensor convertOutputToModelTensor(Object output, String outputKey) { + ModelTensor modelTensor; + if (output instanceof Map) { + modelTensor = ModelTensor.builder().name(outputKey).dataAsMap((Map) output).build(); + } else if (output instanceof List) { + Map resultMap = new HashMap<>(); + resultMap.put("output", output); + modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build(); + } else { + String outputJson = StringUtils.toJson(output); + Map resultMap; + if (StringUtils.isJson(outputJson)) { + resultMap = StringUtils.fromJson(outputJson, "output"); + } else { + resultMap = Map.of("output", output); + } + modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build(); + } + return modelTensor; + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 29ea295a17..9644ee02e6 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -304,6 +304,20 @@ public void parseResponse_NonJsonString() throws IOException { Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response")); } + @Test + public void parseResponse_MapResponse() throws IOException { + HttpConnector connector = createHttpConnector(); + Map responseMap = new HashMap<>(); + responseMap.put("key1", "value1"); + responseMap.put("key2", "value2"); + List modelTensors = new ArrayList<>(); + + connector.parseResponse(responseMap, modelTensors, false); + Assert.assertEquals(1, modelTensors.size()); + Assert.assertEquals("response", modelTensors.get(0).getName()); + Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap()); + } + @Test public void fillNullParameters() { HttpConnector connector = createHttpConnector(); @@ -442,4 +456,33 @@ public void parse_WithTenantId() throws IOException { Assert.assertEquals("test_tenant", connector.getTenantId()); } + @Test + public void testParseResponse_MapResponse() throws IOException { + HttpConnector connector = createHttpConnector(); + + Map responseMap = new HashMap<>(); + responseMap.put("result", "success"); + responseMap.put("data", Arrays.asList("item1", "item2")); + + List modelTensors = new ArrayList<>(); + connector.parseResponse(responseMap, modelTensors, false); + + Assert.assertEquals(1, modelTensors.size()); + Assert.assertEquals("response", modelTensors.get(0).getName()); + Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap()); + } + + @Test + public void testParseResponse_NonStringNonMapResponse() throws IOException { + HttpConnector connector = createHttpConnector(); + + Integer numericResponse = 42; + List modelTensors = new ArrayList<>(); + connector.parseResponse(numericResponse, modelTensors, false); + + Assert.assertEquals(1, modelTensors.size()); + Assert.assertEquals("response", modelTensors.get(0).getName()); + Assert.assertEquals(42, modelTensors.get(0).getDataAsMap().get("response")); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java index b4a79afb98..c52ecbcc4a 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java @@ -170,6 +170,14 @@ public void parse_SkipIrrelevantFields() throws IOException { assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape()); } + @Test + public void test_ToString() { + String result = modelTensorOutput.toString(); + String expected = + "{\"inference_results\":[{\"output\":[{\"name\":\"test\",\"data_type\":\"FLOAT32\",\"shape\":[1,3],\"data\":[1.0,2.0,3.0],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}]}"; + assertEquals(expected, result); + } + private void readInputStream(ModelTensorOutput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java index d41ba82fe4..24b3739a96 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java @@ -6,6 +6,9 @@ package org.opensearch.ml.common.output.model; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; @@ -122,4 +125,23 @@ public void test_NullDataType() { .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) .build(); } + + @Test + public void test_ToString() { + String result = modelTensor.toString(); + String expected = + "{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"},\"result\":\"test result\",\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}"; + assertEquals(expected, result); + } + + @Test + public void test_ToString_ThrowsException() throws IOException { + ModelTensor spyTensor = spy(modelTensor); + doThrow(new IOException("Mock IOException")).when(spyTensor).toXContent(any(XContentBuilder.class), any()); + + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Can't convert ModelTensor to string"); + + spyTensor.toString(); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java index f3f7f98b6c..0d7c660a42 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java @@ -8,6 +8,9 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import java.io.IOException; @@ -274,4 +277,23 @@ public void parse_SkipIrrelevantFields() throws IOException { ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0); assertEquals("test_tensor", modelTensor.getName()); } + + @Test + public void test_ToString() { + String result = modelTensors.toString(); + String expected = + "{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}"; + assertEquals(expected, result); + } + + @Test + public void test_ToString_ThrowsException() throws IOException { + ModelTensors spyTensors = spy(modelTensors); + doThrow(new IOException("Mock IOException")).when(spyTensors).toXContent(any(XContentBuilder.class), any()); + + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Can't convert ModelTensor to string"); + + spyTensors.toString(); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index a64630de40..89348a407f 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -31,6 +31,10 @@ import org.junit.Test; import org.opensearch.OpenSearchParseException; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import com.google.gson.JsonElement; import com.google.gson.TypeAdapter; @@ -244,6 +248,48 @@ public void testGetErrorMessageWhenHiddenNull() { * in the values. Verifies that the method correctly extracts the prefixes of the toString() * method calls. */ + @Test + public void testCollectToStringPrefixes() { + Map map = new HashMap<>(); + map.put("key1", "${parameters.tensor.toString()}"); + map.put("key2", "${parameters.output.toString()}"); + map.put("key3", "normal value"); + + List prefixes = StringUtils.collectToStringPrefixes(map); + + assertEquals(2, prefixes.size()); + assertTrue(prefixes.contains("tensor")); + assertTrue(prefixes.contains("output")); + } + + @Test + public void test_GsonTypeAdapters() { + // Test ModelTensor serialization + ModelTensor tensor = ModelTensor + .builder() + .name("test_tensor") + .data(new Number[] { 1, 2, 3 }) + .dataType(MLResultDataType.INT32) + .build(); + + String tensorJson = StringUtils.gson.toJson(tensor); + assertEquals(tensor.toString(), tensorJson); + + // Test ModelTensorOutput serialization + List outputs = new ArrayList<>(); + outputs.add(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build()); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(outputs).build(); + + String outputJson = StringUtils.gson.toJson(output); + assertEquals(output.toString(), outputJson); + + // Test ModelTensors serialization + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + + String tensorsJson = StringUtils.gson.toJson(tensors); + assertEquals(tensors.toString(), tensorsJson); + } + @Test public void testGetToStringPrefix() { Map parameters = new HashMap<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ToStringTypeAdapterTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ToStringTypeAdapterTest.java new file mode 100644 index 0000000000..7086803057 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/utils/ToStringTypeAdapterTest.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.io.StringWriter; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import com.google.gson.stream.JsonWriter; + +public class ToStringTypeAdapterTest { + + private ToStringTypeAdapter adapter; + private ModelTensor modelTensor; + + @Before + public void setUp() { + adapter = new ToStringTypeAdapter<>(ModelTensor.class); + modelTensor = ModelTensor.builder().name("test_tensor").data(new Number[] { 1, 2, 3 }).dataType(MLResultDataType.INT32).build(); + } + + @Test + public void test_Write_ValidObject() throws IOException { + StringWriter stringWriter = new StringWriter(); + JsonWriter jsonWriter = new JsonWriter(stringWriter); + + adapter.write(jsonWriter, modelTensor); + + String result = stringWriter.toString(); + assertEquals(modelTensor.toString(), result); + } + + @Test + public void test_Write_NullObject() throws IOException { + StringWriter stringWriter = new StringWriter(); + JsonWriter jsonWriter = new JsonWriter(stringWriter); + + adapter.write(jsonWriter, null); + + String result = stringWriter.toString(); + assertEquals("null", result); + } + + @Test + public void test_Read_ThrowsUnsupportedOperationException() { + assertThrows(UnsupportedOperationException.class, () -> { adapter.read(null); }); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java index 67e462f0e6..92e4cdf852 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java @@ -19,6 +19,7 @@ import org.junit.Test; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.output.model.ModelTensor; public class ToolUtilsTest { @@ -343,4 +344,51 @@ public void testExtractInputParameters_NoInputParameter() { assertEquals("value1", result.get("param1")); assertEquals("value2", result.get("param2")); } + + @Test + public void testConvertOutputToModelTensor_WithMap() { + Map mapOutput = Map.of("key1", "value1", "key2", "value2"); + String outputKey = "test_output"; + + ModelTensor result = ToolUtils.convertOutputToModelTensor(mapOutput, outputKey); + + assertEquals(outputKey, result.getName()); + assertEquals(mapOutput, result.getDataAsMap()); + } + + @Test + public void testConvertOutputToModelTensor_WithList() { + List listOutput = List.of("item1", "item2", "item3"); + String outputKey = "test_output"; + + ModelTensor result = ToolUtils.convertOutputToModelTensor(listOutput, outputKey); + + assertEquals(outputKey, result.getName()); + Map expectedMap = Map.of("output", listOutput); + assertEquals(expectedMap, result.getDataAsMap()); + } + + @Test + public void testConvertOutputToModelTensor_WithJsonString() { + String jsonOutput = "{\"key1\":\"value1\",\"key2\":\"value2\"}"; + String outputKey = "test_output"; + + ModelTensor result = ToolUtils.convertOutputToModelTensor(jsonOutput, outputKey); + + assertEquals(outputKey, result.getName()); + assertTrue(result.getDataAsMap().containsKey("key1")); + assertTrue(result.getDataAsMap().containsKey("key2")); + } + + @Test + public void testConvertOutputToModelTensor_WithNonJsonString() { + String stringOutput = "simple string output"; + String outputKey = "test_output"; + + ModelTensor result = ToolUtils.convertOutputToModelTensor(stringOutput, outputKey); + + assertEquals(outputKey, result.getName()); + Map expectedMap = Map.of("output", stringOutput); + assertEquals(expectedMap, result.getDataAsMap()); + } } diff --git a/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_claude.md b/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_claude.md new file mode 100644 index 0000000000..fe676bded7 --- /dev/null +++ b/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_claude.md @@ -0,0 +1,543 @@ +# 1. Create Model + +## 1.1 LLM + +### 1.1.1 Create LLM +``` +POST _plugins/_ml/models/_register +{ + "name": "Bedrock Claude 3.7 model", + "function_name": "remote", + "description": "test model", + "connector": { + "name": "Bedrock Claude3.7 connector", + "description": "test connector", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "us-west-2", + "service_name": "bedrock", + "model": "us.anthropic.claude-3-7-sonnet-20250219-v1:0" + }, + "credential": { + "access_key": "xxx", + "secret_key": "xxx" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse", + "headers": { + "content-type": "application/json" + }, + "request_body": "{ \"system\": [{\"text\": \"${parameters.system_prompt}\"}], \"messages\": [${parameters._chat_history:-}{\"role\":\"user\",\"content\":[{\"text\":\"${parameters.prompt}\"}]}${parameters._interactions:-}]${parameters.tool_configs:-} }" + } + ] + } +} +``` + +Sampel output +``` +{ + "task_id": "t8c_mJgBLapFVETfK14Y", + "status": "CREATED", + "model_id": "uMc_mJgBLapFVETfK15H" +} +``` + +### 1.1.2 Test Tool Usage + +``` +POST _plugins/_ml/models/uMc_mJgBLapFVETfK15H/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "prompt": "What's the weather in Seattle and Beijing?", + "tool_configs": ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}", + "_tools": "{\"toolSpec\":{\"name\":\"getWeather\",\"description\":\"Get the current weather for a location\",\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"City name, e.g. Seattle, WA\"}},\"required\":[\"location\"]}}}}", + "no_escape_params": "tool_configs,_tools" + } +} +``` +Sample output +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 1917.0 + }, + "output": { + "message": { + "content": [ + { + "text": "I'll help you check the current weather in both Seattle and Beijing. Let me get that information for you." + }, + { + "toolUse": { + "input": { + "location": "Seattle, WA" + }, + "name": "getWeather", + "toolUseId": "tooluse_okU4kGWgSvm0F9KYpqUOyA" + } + } + ], + "role": "assistant" + } + }, + "stopReason": "tool_use", + "usage": { + "cacheReadInputTokenCount": 0.0, + "cacheReadInputTokens": 0.0, + "cacheWriteInputTokenCount": 0.0, + "cacheWriteInputTokens": 0.0, + "inputTokens": 407.0, + "outputTokens": 79.0, + "totalTokens": 486.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +Test example 2 +``` +POST /_plugins/_ml/models/uMc_mJgBLapFVETfK15H/_predict +{ + "parameters": { + "system_prompt": "You are an expert RAG (Retrieval Augmented Generation) assistant with access to OpenSearch indices. Your primary purpose is to help users find and understand information by retrieving relevant content from the OpenSearch indices.\n\nSEARCH GUIDELINES:\n1. When a user asks for information:\n - First determine which index to search (use ListIndexTool if unsure)\n - For complex questions, break them down into smaller sub-questions\n - Be specific and targeted in your search queries\n\n\nRESPONSE GUIDELINES:\n1. Always cite which index the information came from\n2. Synthesize search results into coherent, well-structured responses\n3. Use formatting (headers, bullet points) to organize information when appropriate\n4. If search results are insufficient, acknowledge limitations and suggest alternatives\n5. Maintain context from the conversation history when formulating responses\n\nPrioritize accuracy over speculation. Be professional, helpful, and concise in your interactions.", + "prompt": "How many flights from China to USA", + "tool_configs": ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}", + "_tools": "{\"toolSpec\":{\"name\":\"getWeather\",\"description\":\"Get the current weather for a location\",\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"City name, e.g. Seattle, WA\"}},\"required\":[\"location\"]}}}}", + "no_escape_params": "tool_configs,_tools, _interactions", + "_interactions": ", {\"content\":[{\"text\":\"\\u003creasoning\\u003eThe user asks: \\\"How many flights from China to USA\\\". They want a number. Likely they need data from an index that tracks flight data. We need to search relevant index. Not sure which index exists. Let\\u0027s list indices.\\u003c/reasoning\\u003e\"},{\"toolUse\":{\"input\":{\"indices\":[]},\"name\":\"ListIndexTool\",\"toolUseId\":\"tooluse_1dXChRccQnap_Yf5Icf2vw\"}}],\"role\":\"assistant\"},{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"tooluse_1dXChRccQnap_Yf5Icf2vw\",\"content\":[{\"text\":\"row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\\n1,green,open,.plugins-ml-model-group,HU6XuHpjRp6mPe4GYn0qqQ,1,0,4,0,11.8kb,11.8kb\\n2,green,open,otel-v1-apm-service-map-sample,KYlj9cclQXCFWKTsNGV3Og,1,0,49,0,17.7kb,17.7kb\\n3,green,open,.plugins-ml-memory-meta,CCnEW7ixS_KrJdmEl330lw,1,0,546,29,209.2kb,209.2kb\\n4,green,open,.ql-datasources,z4aXBn1oQoSXNiCMvwapkw,1,0,0,0,208b,208b\\n5,green,open,security-auditlog-2025.08.11,IwYKN08KTJqGnIJ5tVURzA,1,0,30,0,270.3kb,270.3kb\\n6,yellow,open,test_tech_news,2eA8yCDIRcaKoUhDwnD4uQ,1,1,3,0,23.1kb,23.1kb\\n7,yellow,open,test_mem_container,VS57K2JvQmyXlJxDwJ9nhA,1,1,0,0,208b,208b\\n8,green,open,.plugins-ml-task,GTX8717YRtmw0nkl1vnXMg,1,0,691,28,107.6kb,107.6kb\\n9,green,open,ss4o_metrics-otel-sample,q5W0oUlZRuaEVtfL1dknKA,1,0,39923,0,4.3mb,4.3mb\\n10,green,open,.opendistro_security,3AmSFRp1R42BSFMPfXXVqw,1,0,9,0,83.7kb,83.7kb\\n11,yellow,open,test_population_data,fkgI4WqQSv2qHSaXRJv_-A,1,1,4,0,22.5kb,22.5kb\\n12,green,open,.plugins-ml-model,Ksvz5MEwR42RmrOZAg6RTQ,1,0,18,31,406.8kb,406.8kb\\n13,green,open,.plugins-ml-memory-container,TYUOuxfMT5i4DdOMfYe45w,1,0,2,0,9.9kb,9.9kb\\n14,green,open,.kibana_-969161597_admin123_1,dlCY0iSQR9aTtRJsJx9cIg,1,0,1,0,5.3kb,5.3kb\\n15,green,open,opensearch_dashboards_sample_data_ecommerce,HdLqBN2QQCC4nUhE-5mOkA,1,0,4675,0,4.1mb,4.1mb\\n16,green,open,security-auditlog-2025.08.04,Ac34ImxkQSeonGeW8RSYcw,1,0,985,0,1mb,1mb\\n17,green,open,.plugins-ml-memory-message,osbPSzqeQd2sCPF2yqYMeg,1,0,2489,11,4mb,4mb\\n18,green,open,security-auditlog-2025.08.05,hBi-5geLRqml4mDixXMPYQ,1,0,854,0,1.1mb,1.1mb\\n19,green,open,security-auditlog-2025.08.07,bY2GJ1Z8RAG7qJb8dcHktw,1,0,5,0,75.8kb,75.8kb\\n20,green,open,otel-v1-apm-span-sample,mewMjWmHSdOlGNzgz-OnXg,1,0,13061,0,6.2mb,6.2mb\\n21,green,open,security-auditlog-2025.08.01,lvwEW0_VQnq88o7XjJkoPw,1,0,57,0,216.9kb,216.9kb\\n22,green,open,.kibana_92668751_admin_1,w45HvTqcTQ-eVTWK77lx7A,1,0,242,0,125.2kb,125.2kb\\n23,green,open,.plugins-ml-agent,Q-kWqw9_RdW3SfhUcKBcfQ,1,0,156,0,423.1kb,423.1kb\\n24,green,open,security-auditlog-2025.08.03,o_iRgiWZRryXrhQ-y36JQg,1,0,847,0,575.9kb,575.9kb\\n25,green,open,.kibana_1,tn9JwuiSSe6gR9sInHFxDg,1,0,4,0,16.7kb,16.7kb\\n26,green,open,security-auditlog-2025.08.09,uugjBoB_SR67h4455K5KPw,1,0,58,0,161.9kb,161.9kb\\n27,green,open,ss4o_logs-otel-sample,wIaGQx8wT1ijhxDjjbux6A,1,0,16286,0,5.6mb,5.6mb\\n28,green,open,.plugins-ml-config,uh1sYBiTRYy_lK61Rx0fkQ,1,0,1,0,4kb,4kb\\n29,green,open,opensearch_dashboards_sample_data_logs,jDDgZ_PfTkKHfaDqBl72lQ,1,0,14074,0,8.3mb,8.3mb\\n30,green,open,opensearch_dashboards_sample_data_flights,7dcplp7MRc2C9WTPyLRRIg,1,0,13059,0,5.8mb,5.8mb\\n\"}]}}]}" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 4082.0 + }, + "output": { + "message": { + "content": [ + { + "text": "I notice there's an index called \"opensearch_dashboards_sample_data_flights\" that might contain the information you're looking for. Let me search that index for flights from China to USA." + }, + { + "toolUse": { + "input": { + "index": "opensearch_dashboards_sample_data_flights", + "query": "OriginCountry:China AND DestCountry:\"United States\"" + }, + "name": "SearchIndexTool", + "toolUseId": "tooluse_ym7ukb5xR46h-fFW8X3h-w" + } + } + ], + "role": "assistant" + } + }, + "stopReason": "tool_use", + "usage": { + "cacheReadInputTokenCount": 0.0, + "cacheReadInputTokens": 0.0, + "cacheWriteInputTokenCount": 0.0, + "cacheWriteInputTokens": 0.0, + "inputTokens": 2417.0, + "outputTokens": 140.0, + "totalTokens": 2557.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +### 1.1.3 Test final resposne +``` +POST _plugins/_ml/models/uMc_mJgBLapFVETfK15H/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "prompt": "What's the capital of USA?" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 1471.0 + }, + "output": { + "message": { + "content": [ + { + "text": "The capital of the United States of America is Washington, D.C. (District of Columbia). It's named after George Washington, the first President of the United States, and has served as the nation's capital since 1790." + } + ], + "role": "assistant" + } + }, + "stopReason": "end_turn", + "usage": { + "cacheReadInputTokenCount": 0.0, + "cacheReadInputTokens": 0.0, + "cacheWriteInputTokenCount": 0.0, + "cacheWriteInputTokens": 0.0, + "inputTokens": 20.0, + "outputTokens": 51.0, + "totalTokens": 71.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +### 1.1.1 Test Tool Usage + +## 1.2 Embedding Model + +### 1.2.1 Create Embedding Model +``` +POST _plugins/_ml/models/_register +{ + "name": "Bedrock embedding model", + "function_name": "remote", + "description": "Bedrock Titan Embedding Model V2", + "connector": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "us-west-2", + "service_name": "bedrock", + "model": "amazon.titan-embed-text-v2:0", + "dimensions": 1024, + "normalize": true, + "embeddingTypes": [ + "float" + ] + }, + "credential": { + "access_key": "xxx", + "secret_key": "xxx" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}, \"normalize\": ${parameters.normalize}, \"embeddingTypes\": ${parameters.embeddingTypes} }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} +``` +Sample response +``` +{ + "task_id": "xsdDmJgBLapFVETfAl6Z", + "status": "CREATED", + "model_id": "x8dDmJgBLapFVETfAl63" +} +``` + +### 1.2.2 Test Embedding Model +``` +POST _plugins/_ml/models/x8dDmJgBLapFVETfAl63/_predict?algorithm=text_embedding +{ + "text_docs": [ + "hello", + "how are you" + ] +} +``` +or +``` +POST _plugins/_ml/models/x8dDmJgBLapFVETfAl63/_predict +{ + "parameters": { + "inputText": "how are you" + } +} +``` + +# 2. Agent +## 2.1 Flow Agent + +### 2.1.1 Create Flow Agent +``` +POST /_plugins/_ml/agents/_register +{ + "name": "Query DSL Translator Agent", + "type": "flow", + "description": "This is a demo agent for translating NLQ to OpenSearcdh DSL", + "tools": [ + { + "type": "IndexMappingTool", + "include_output_in_agent_response": false, + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\"}" + } + }, + { + "type": "SearchIndexTool", + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\", \"query\": ${parameters.query_dsl} }", + "query_dsl": "{\"size\":2,\"query\":{\"match_all\":{}}}" + }, + "include_output_in_agent_response": false + }, + { + "type": "MLModelTool", + "name": "generate_os_query_dsl", + "description": "A tool to generate OpenSearch query DSL with natrual language question.", + "parameters": { + "response_filter": "$.output.message.content[0].text", + "model_id": "uMc_mJgBLapFVETfK15H", + "embedding_model_id": "x8dDmJgBLapFVETfAl63-V", + "system_prompt": "You are an OpenSearch query generator that converts natural language questions into precise OpenSearch Query DSL JSON. Your ONLY response should be the valid JSON DSL query without any explanations or additional text.\n\nFollow these rules:\n1. Analyze the index mapping and sample document first, use the exact field name, DON'T use non-existing field name in generated query DSL\n2. Analyze the question to identify search criteria, filters, sorting, and result limits\n3. Extract specific parameters (fields, values, operators, size) mentioned in the question\n4. Apply the most appropriate query type (match, match_all, term, range, bool, etc.)\n5. Return ONLY the JSON query DSL with proper formatting\n\nNEURAL SEARCH GUIDANCE:\n1. OpenSearch KNN index can be identified index settings with `\"knn\": \"true\",`; or in index mapping with any field with `\"type\": \"knn_vector\"`\n2. If search KNN index, prefer to use OpenSearch neural search query which is semantic based search, and has better accuracy.\n3. OpenSearch neural search needs embedding model id, please always use this model id \"${parameters.embedding_model_id}\"\n4. In KNN indices, embedding fields follow the pattern: {text_field_name}_embedding. For example, the raw text input is \"description\", then the generated embedding for this field will be saved into KNN field \"description_embedding\". \n5. Always exclude embedding fields from search results as they contain vector arrays that clutter responses\n6. Embedding fields can be identified in index mapping with \"type\": \"knn_vector\"\n7. OpenSearch neural search query will use embedding field (knn_vector type) and embedding model id. \n\nNEURAL SEARCH QUERY CONSTRUCTION:\nWhen constructing neural search queries, follow this pattern:\n{\n \"_source\": {\n \"excludes\": [\n \"{field_name}_embedding\"\n ]\n },\n \"query\": {\n \"neural\": {\n \"{field_name}_embedding\": {\n \"query_text\": \"your query here\",\n \"model_id\": \"${parameters.embedding_model_id}\"\n }\n }\n }\n}\n\nRESPONSE GUIDELINES:\n1. Don't return the reasoning process, just return the generated OpenSearch query.\n2. Don't wrap the generated OpenSearch query with ```json and ```\n\nExamples:\n\nQuestion: retrieve 5 documents from index test_data\n{\"query\":{\"match_all\":{}},\"size\":5}\n\nQuestion: find documents where the field title contains machine learning\n{\"query\":{\"match\":{\"title\":\"machine learning\"}}}\n\nQuestion: search for documents with the phrase artificial intelligence in the content field and return top 10 results\n{\"query\":{\"match_phrase\":{\"content\":\"artificial intelligence\"}},\"size\":10}\n\nQuestion: get documents where price is greater than 100 and category is electronics\n{\"query\":{\"bool\":{\"must\":[{\"range\":{\"price\":{\"gt\":100}}},{\"term\":{\"category\":\"electronics\"}}]}}}\n\nQuestion: find the average rating of products in the electronics category\n{\"query\":{\"term\":{\"category\":\"electronics\"}},\"aggs\":{\"avg_rating\":{\"avg\":{\"field\":\"rating\"}}},\"size\":0}\n\nQuestion: return documents sorted by date in descending order, limit to 15 results\n{\"query\":{\"match_all\":{}},\"sort\":[{\"date\":{\"order\":\"desc\"}}],\"size\":15}\n\nQuestion: which book has the introduction of AWS AgentCore\n{\"_source\":{\"excludes\":[\"book_content_embedding\"]},\"query\":{\"neural\":{\"book_content_embedding\":{\"query_text\":\"which book has the introduction of AWS AgentCore\"}}}}\n\nQuestion: how many books published in 2024\n{\"query\": {\"term\": {\"publication_year\": 2024}},\"size\": 0,\"track_total_hits\": true}\n", + "prompt": "The index mappoing of ${parameters.index_name}:\n${parameters.IndexMappingTool.output:-}\n\nThe sample documents of ${parameters.index_name}:\n${parameters.SearchIndexTool.output:-}\n\nPlease generate the OpenSearch query dsl for the question:\n${parameters.question}" + }, + "include_output_in_agent_response": false + }, + { + "type": "SearchIndexTool", + "name": "search_index_with_llm_generated_dsl", + "include_output_in_agent_response": false, + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\", \"query\": ${parameters.query_dsl} }", + "query_dsl": "${parameters.generate_os_query_dsl.output}", + "return_raw_response": true + }, + "attributes": { + "required_parameters": [ + "index_name", + "query_dsl", + "generate_os_query_dsl.output" + ] + } + } + ] +} +``` +Sample response +``` +{ + "agent_id": "y8dEmJgBLapFVETfMl4P" +} +``` +### 2.1.2 Test Flow Agent +``` +POST _plugins/_ml/agents/y8dEmJgBLapFVETfMl4P/_execute +{ + "parameters": { + "question": "How many total flights from Beijing?", + "index_name": "opensearch_dashboards_sample_data_flights" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "search_index_with_llm_generated_dsl", + "dataAsMap": { + "_shards": { + "total": 1, + "failed": 0, + "successful": 1, + "skipped": 0 + }, + "hits": { + "hits": [], + "total": { + "value": 131, + "relation": "eq" + }, + "max_score": null + }, + "took": 3, + "timed_out": false + } + } + ] + } + ] +} +``` + +## 2.2 Chat Agent + +### 2.2.1 Create Chat Agent +``` +POST _plugins/_ml/agents/_register +{ + "name": "RAG Agent", + "type": "conversational", + "description": "this is a test agent", + "app_type": "rag", + "llm": { + "model_id": "uMc_mJgBLapFVETfK15H", + "parameters": { + "max_iteration": 10, + "system_prompt": "You are an expert RAG (Retrieval Augmented Generation) assistant with access to OpenSearch indices. Your primary purpose is to help users find and understand information by retrieving relevant content from the OpenSearch indices.\n\nSEARCH GUIDELINES:\n1. When a user asks for information:\n - First determine which index to search (use ListIndexTool if unsure)\n - For complex questions, break them down into smaller sub-questions\n - Be specific and targeted in your search queries\n\n\nRESPONSE GUIDELINES:\n1. Always cite which index the information came from\n2. Synthesize search results into coherent, well-structured responses\n3. Use formatting (headers, bullet points) to organize information when appropriate\n4. If search results are insufficient, acknowledge limitations and suggest alternatives\n5. Maintain context from the conversation history when formulating responses\n\nPrioritize accuracy over speculation. Be professional, helpful, and concise in your interactions.", + "prompt": "${parameters.question}" + } + }, + "memory": { + "type": "conversation_index" + }, + "parameters": { + "_llm_interface": "bedrock/converse/claude" + }, + "tools": [ + { + "type": "ListIndexTool" + }, + { + "type": "AgentTool", + "name": "search_opensearch_index_with_nlq", + "include_output_in_agent_response": false, + "description": "This tool accepts one OpenSearch index and one natrual language question and generate OpenSearch query DSL. Then query the index with generated query DSL. If the question if complex, suggest split it into smaller questions then query one by one.", + "parameters": { + "agent_id": "y8dEmJgBLapFVETfMl4P", + "output_filter": "$.mlModelOutputs[0].mlModelTensors[2].dataAsMap" + }, + "attributes": { + "required_parameters": [ + "index_name", + "question" + ], + "input_schema": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "Natural language question" + }, + "index_name": { + "type": "string", + "description": "Name of the index to query" + } + }, + "required": [ + "question" + ], + "additionalProperties": false + }, + "strict": false + } + } + ] +} +``` +Sample response +``` +{ + "agent_id": "08dFmJgBLapFVETf_V6R" +} +``` + +### 2.2.2 Test Chat Agent +``` +POST /_plugins/_ml/agents/08dFmJgBLapFVETf_V6R/_execute +{ + "parameters": { + "question": "How many flights from Seattle to Canada", + "max_iteration": 30, + "verbose": true + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "memory_id", + "result": "18dGmJgBLapFVETfyl6I" + }, + { + "name": "parent_interaction_id", + "result": "2MdGmJgBLapFVETfyl6a" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":2309.0},\"output\":{\"message\":{\"content\":[{\"text\":\"I'll help you find information about flights from Seattle to Canada. Let me search for this data.\"},{\"toolUse\":{\"input\":{\"indices\":[]},\"name\":\"ListIndexTool\",\"toolUseId\":\"tooluse_QXoAl62QTYueDxzsSWmLNA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0.0,\"cacheReadInputTokens\":0.0,\"cacheWriteInputTokenCount\":0.0,\"cacheWriteInputTokens\":0.0,\"inputTokens\":907.0,\"outputTokens\":75.0,\"totalTokens\":982.0}}" + }, + { + "name": "response", + "result": "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\n1,green,open,.plugins-ml-model-group,HU6XuHpjRp6mPe4GYn0qqQ,1,0,4,3,13.8kb,13.8kb\n2,green,open,otel-v1-apm-service-map-sample,KYlj9cclQXCFWKTsNGV3Og,1,0,49,0,17.7kb,17.7kb\n3,green,open,.plugins-ml-memory-meta,CCnEW7ixS_KrJdmEl330lw,1,0,569,10,126.4kb,126.4kb\n4,green,open,.ql-datasources,z4aXBn1oQoSXNiCMvwapkw,1,0,0,0,208b,208b\n5,green,open,security-auditlog-2025.08.11,IwYKN08KTJqGnIJ5tVURzA,1,0,167,0,732kb,732kb\n6,yellow,open,test_tech_news,2eA8yCDIRcaKoUhDwnD4uQ,1,1,3,0,23.1kb,23.1kb\n7,yellow,open,test_mem_container,VS57K2JvQmyXlJxDwJ9nhA,1,1,0,0,208b,208b\n8,green,open,.plugins-ml-task,GTX8717YRtmw0nkl1vnXMg,1,0,824,37,169.9kb,169.9kb\n9,green,open,ss4o_metrics-otel-sample,q5W0oUlZRuaEVtfL1dknKA,1,0,39923,0,4.3mb,4.3mb\n10,green,open,.opendistro_security,3AmSFRp1R42BSFMPfXXVqw,1,0,9,0,83.7kb,83.7kb\n11,yellow,open,test_population_data,fkgI4WqQSv2qHSaXRJv_-A,1,1,4,0,22.5kb,22.5kb\n12,green,open,.plugins-ml-model,Ksvz5MEwR42RmrOZAg6RTQ,1,0,30,2,591.1kb,591.1kb\n13,green,open,.plugins-ml-memory-container,TYUOuxfMT5i4DdOMfYe45w,1,0,2,0,9.9kb,9.9kb\n14,green,open,.kibana_-969161597_admin123_1,dlCY0iSQR9aTtRJsJx9cIg,1,0,1,0,5.3kb,5.3kb\n15,green,open,opensearch_dashboards_sample_data_ecommerce,HdLqBN2QQCC4nUhE-5mOkA,1,0,4675,0,4.1mb,4.1mb\n16,green,open,security-auditlog-2025.08.04,Ac34ImxkQSeonGeW8RSYcw,1,0,985,0,1mb,1mb\n17,green,open,.plugins-ml-memory-message,osbPSzqeQd2sCPF2yqYMeg,1,0,2662,8,4.1mb,4.1mb\n18,green,open,security-auditlog-2025.08.05,hBi-5geLRqml4mDixXMPYQ,1,0,854,0,1.1mb,1.1mb\n19,green,open,security-auditlog-2025.08.07,bY2GJ1Z8RAG7qJb8dcHktw,1,0,5,0,75.8kb,75.8kb\n20,green,open,otel-v1-apm-span-sample,mewMjWmHSdOlGNzgz-OnXg,1,0,13061,0,6.2mb,6.2mb\n21,green,open,security-auditlog-2025.08.01,lvwEW0_VQnq88o7XjJkoPw,1,0,57,0,216.9kb,216.9kb\n22,green,open,.kibana_92668751_admin_1,w45HvTqcTQ-eVTWK77lx7A,1,0,242,0,125.2kb,125.2kb\n23,green,open,.plugins-ml-agent,Q-kWqw9_RdW3SfhUcKBcfQ,1,0,167,0,483.2kb,483.2kb\n24,green,open,security-auditlog-2025.08.03,o_iRgiWZRryXrhQ-y36JQg,1,0,847,0,575.9kb,575.9kb\n25,green,open,.kibana_1,tn9JwuiSSe6gR9sInHFxDg,1,0,4,0,16.7kb,16.7kb\n26,green,open,security-auditlog-2025.08.09,uugjBoB_SR67h4455K5KPw,1,0,58,0,161.9kb,161.9kb\n27,green,open,ss4o_logs-otel-sample,wIaGQx8wT1ijhxDjjbux6A,1,0,16286,0,5.6mb,5.6mb\n28,green,open,.plugins-ml-config,uh1sYBiTRYy_lK61Rx0fkQ,1,0,1,0,4kb,4kb\n29,green,open,opensearch_dashboards_sample_data_logs,jDDgZ_PfTkKHfaDqBl72lQ,1,0,14074,0,8.3mb,8.3mb\n30,green,open,opensearch_dashboards_sample_data_flights,7dcplp7MRc2C9WTPyLRRIg,1,0,13059,0,5.8mb,5.8mb\n" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":4234.0},\"output\":{\"message\":{\"content\":[{\"text\":\"I see that there's a flights dataset available in the index named \\\"opensearch_dashboards_sample_data_flights\\\". Let me search for flights from Seattle to Canada in this dataset.\"},{\"toolUse\":{\"input\":{\"index_name\":\"opensearch_dashboards_sample_data_flights\",\"question\":\"How many flights from Seattle to Canada?\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_webI6myfTNm9r00O12tLEA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0.0,\"cacheReadInputTokens\":0.0,\"cacheWriteInputTokenCount\":0.0,\"cacheWriteInputTokens\":0.0,\"inputTokens\":2688.0,\"outputTokens\":138.0,\"totalTokens\":2826.0}}" + }, + { + "name": "response", + "result": "{\"_shards\":{\"total\":1,\"failed\":0,\"successful\":1,\"skipped\":0},\"hits\":{\"hits\":[],\"total\":{\"value\":5,\"relation\":\"eq\"}},\"took\":6,\"timed_out\":false}" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":4182.0},\"output\":{\"message\":{\"content\":[{\"text\":\"According to the search results from the \\\"opensearch_dashboards_sample_data_flights\\\" index, there are 5 flights from Seattle to Canada.\\n\\nLet me get more details about these flights:\"},{\"toolUse\":{\"input\":{\"question\":\"Show details of flights from Seattle to Canadian destinations including destination city, carrier, and flight dates\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_yNQ0uCQ_SmO_sgrPVACkdA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"cacheReadInputTokenCount\":0.0,\"cacheReadInputTokens\":0.0,\"cacheWriteInputTokenCount\":0.0,\"cacheWriteInputTokens\":0.0,\"inputTokens\":2931.0,\"outputTokens\":152.0,\"totalTokens\":3083.0}}" + }, + { + "name": "response", + "result": "{\"_shards\":{\"total\":1,\"failed\":0,\"successful\":1,\"skipped\":0},\"hits\":{\"hits\":[{\"_index\":\"opensearch_dashboards_sample_data_flights\",\"_source\":{\"FlightNum\":\"U5MKUYM\",\"Origin\":\"Seattle Tacoma International Airport\",\"Dest\":\"Winnipeg / James Armstrong Richardson International Airport\",\"Carrier\":\"Logstash Airways\",\"timestamp\":\"2025-08-29T10:56:34\",\"DestCityName\":\"Winnipeg\"},\"_id\":\"gUNFbZgBHZOGNbY88Fea\",\"_score\":2.0},{\"_index\":\"opensearch_dashboards_sample_data_flights\",\"_source\":{\"FlightNum\":\"UF2YYSK\",\"Origin\":\"Seattle Tacoma International Airport\",\"Dest\":\"Winnipeg / James Armstrong Richardson International Airport\",\"Carrier\":\"OpenSearch Dashboards Airlines\",\"timestamp\":\"2025-08-31T09:13:05\",\"DestCityName\":\"Winnipeg\"},\"_id\":\"i0NFbZgBHZOGNbY88V0e\",\"_score\":2.0},{\"_index\":\"opensearch_dashboards_sample_data_flights\",\"_source\":{\"FlightNum\":\"GDO8L2V\",\"Origin\":\"Seattle Tacoma International Airport\",\"Dest\":\"Winnipeg / James Armstrong Richardson International Airport\",\"Carrier\":\"BeatsWest\",\"timestamp\":\"2025-09-01T17:20:52\",\"DestCityName\":\"Winnipeg\"},\"_id\":\"wkNFbZgBHZOGNbY88WKK\",\"_score\":2.0},{\"_index\":\"opensearch_dashboards_sample_data_flights\",\"_source\":{\"FlightNum\":\"E69LO59\",\"Origin\":\"Boeing Field King County International Airport\",\"Dest\":\"Edmonton International Airport\",\"Carrier\":\"OpenSearch Dashboards Airlines\",\"timestamp\":\"2025-09-04T10:14:29\",\"DestCityName\":\"Edmonton\"},\"_id\":\"h0NFbZgBHZOGNbY88Wj1\",\"_score\":2.0},{\"_index\":\"opensearch_dashboards_sample_data_flights\",\"_source\":{\"FlightNum\":\"LM8J3R1\",\"Origin\":\"Boeing Field King County International Airport\",\"Dest\":\"Montreal / Pierre Elliott Trudeau International Airport\",\"Carrier\":\"BeatsWest\",\"timestamp\":\"2025-09-06T22:10:20\",\"DestCityName\":\"Montreal\"},\"_id\":\"-0NFbZgBHZOGNbY88nKX\",\"_score\":2.0}],\"total\":{\"value\":5,\"relation\":\"eq\"},\"max_score\":2.0},\"took\":4,\"timed_out\":false}" + }, + { + "name": "response", + "result": "Based on the data from the \"opensearch_dashboards_sample_data_flights\" index, there are 5 flights from Seattle to Canada. Here are the details:\n\n### Flights from Seattle to Canada\n\n1. **Flight to Winnipeg**\n - Flight Number: U5MKUYM\n - Origin: Seattle Tacoma International Airport\n - Destination: Winnipeg / James Armstrong Richardson International Airport\n - Carrier: Logstash Airways\n - Date: August 29, 2025\n\n2. **Flight to Winnipeg**\n - Flight Number: UF2YYSK\n - Origin: Seattle Tacoma International Airport\n - Destination: Winnipeg / James Armstrong Richardson International Airport\n - Carrier: OpenSearch Dashboards Airlines\n - Date: August 31, 2025\n\n3. **Flight to Winnipeg**\n - Flight Number: GDO8L2V\n - Origin: Seattle Tacoma International Airport\n - Destination: Winnipeg / James Armstrong Richardson International Airport\n - Carrier: BeatsWest\n - Date: September 1, 2025\n\n4. **Flight to Edmonton**\n - Flight Number: E69LO59\n - Origin: Boeing Field King County International Airport (Seattle)\n - Destination: Edmonton International Airport\n - Carrier: OpenSearch Dashboards Airlines\n - Date: September 4, 2025\n\n5. **Flight to Montreal**\n - Flight Number: LM8J3R1\n - Origin: Boeing Field King County International Airport (Seattle)\n - Destination: Montreal / Pierre Elliott Trudeau International Airport\n - Carrier: BeatsWest\n - Date: September 6, 2025\n\nIn summary, there are 5 flights from Seattle to Canadian cities: 3 to Winnipeg, 1 to Edmonton, and 1 to Montreal, operated by different carriers." + } + ] + } + ] +} +``` \ No newline at end of file diff --git a/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_openai_oss.md b/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_openai_oss.md new file mode 100644 index 0000000000..5322092b64 --- /dev/null +++ b/docs/tutorials/agent_framework/rag/agentic_rag_bedrock_openai_oss.md @@ -0,0 +1,633 @@ +# 1. Create Model + +## 1.1 LLM + +### 1.1.1 Create LLM + +- `reasoning_effort`: "low", "medium", "high" + +``` +POST _plugins/_ml/models/_register +{ + "name": "Bedrock OpenAI GPT OSS 120b", + "function_name": "remote", + "description": "test model", + "connector": { + "name": "Bedrock OpenAI GPT OSS connector", + "description": "test connector", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "us-west-2", + "service_name": "bedrock", + "model": "openai.gpt-oss-120b-1:0", + "return_data_as_map": true, + "reasoning_effort": "high", + "output_processors": [ + { + "type": "conditional", + "path": "$.output.message.content[*].toolUse", + "routes": [ + { + "exists": [ + { + "type": "regex_replace", + "pattern": "\"stopReason\"\\s*:\\s*\"end_turn\"", + "replacement": "\"stopReason\": \"tool_use\"" + } + ] + }, + { + "not_exists": [ + { + "type": "regex_replace", + "pattern": ".*?", + "replacement": "" + } + ] + } + ] + }, + { + "type": "remove_jsonpath", + "path": "$.output.message.content[0]" + } + ] + }, + "credential": { + "access_key": "xxx", + "secret_key": "xxx" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse", + "headers": { + "content-type": "application/json" + }, + "request_body": "{ \"additionalModelRequestFields\": {\"reasoning_effort\": \"${parameters.reasoning_effort}\"}, \"system\": [{\"text\": \"${parameters.system_prompt}\"}], \"messages\": [${parameters._chat_history:-}{\"role\":\"user\",\"content\":[{\"text\":\"${parameters.prompt}\"}]}${parameters._interactions:-}]${parameters.tool_configs:-} }" + } + ], + "client_config": { + "max_retry_times": 5, + "retry_backoff_policy": "exponential_equal_jitter", + "retry_backoff_millis": 5000 + } + }, + "interface": {} +} +``` + +Sampel output +``` +{ + "task_id": "aPArmJgBCqG4iVqlioAh", + "status": "CREATED", + "model_id": "afArmJgBCqG4iVqlioA9" +} +``` + +### 1.1.2 Test Tool Usage + +``` +POST _plugins/_ml/models/afArmJgBCqG4iVqlioA9/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "prompt": "What's the weather in Seattle and Beijing?", + "tool_configs": ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}", + "_tools": "{\"toolSpec\":{\"name\":\"getWeather\",\"description\":\"Get the current weather for a location\",\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"City name, e.g. Seattle, WA\"}},\"required\":[\"location\"]}}}}", + "no_escape_params": "tool_configs,_tools" + } +} +``` +Sample output +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 10011.0 + }, + "output": { + "message": { + "content": [ + { + "text": "" + }, + { + "toolUse": { + "input": { + "location": "Seattle" + }, + "name": "getWeather", + "toolUseId": "tooluse_t-ICDhbRQUyB3HQsFriRcw" + } + } + ], + "role": "assistant" + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 28.0, + "outputTokens": 36.0, + "totalTokens": 64.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +Test example 2 +``` +POST /_plugins/_ml/models/afArmJgBCqG4iVqlioA9/_predict +{ + "parameters": { + "system_prompt": "You are an expert RAG (Retrieval Augmented Generation) assistant with access to OpenSearch indices. Your primary purpose is to help users find and understand information by retrieving relevant content from the OpenSearch indices.\n\nSEARCH GUIDELINES:\n1. When a user asks for information:\n - First determine which index to search (use ListIndexTool if unsure)\n - For complex questions, break them down into smaller sub-questions\n - Be specific and targeted in your search queries\n\n\nRESPONSE GUIDELINES:\n1. Always cite which index the information came from\n2. Synthesize search results into coherent, well-structured responses\n3. Use formatting (headers, bullet points) to organize information when appropriate\n4. If search results are insufficient, acknowledge limitations and suggest alternatives\n5. Maintain context from the conversation history when formulating responses\n\nPrioritize accuracy over speculation. Be professional, helpful, and concise in your interactions.", + "prompt": "How many flights from China to USA", + "tool_configs": ", \"toolConfig\": {\"tools\": [${parameters._tools:-}]}", + "_tools": "{\"toolSpec\":{\"name\":\"getWeather\",\"description\":\"Get the current weather for a location\",\"inputSchema\":{\"json\":{\"type\":\"object\",\"properties\":{\"location\":{\"type\":\"string\",\"description\":\"City name, e.g. Seattle, WA\"}},\"required\":[\"location\"]}}}}", + "no_escape_params": "tool_configs,_tools, _interactions", + "_interactions": ", {\"content\":[{\"text\":\"\\u003creasoning\\u003eThe user asks: \\\"How many flights from China to USA\\\". They want a number. Likely they need data from an index that tracks flight data. We need to search relevant index. Not sure which index exists. Let\\u0027s list indices.\\u003c/reasoning\\u003e\"},{\"toolUse\":{\"input\":{\"indices\":[]},\"name\":\"ListIndexTool\",\"toolUseId\":\"tooluse_1dXChRccQnap_Yf5Icf2vw\"}}],\"role\":\"assistant\"},{\"role\":\"user\",\"content\":[{\"toolResult\":{\"toolUseId\":\"tooluse_1dXChRccQnap_Yf5Icf2vw\",\"content\":[{\"text\":\"row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\\n1,green,open,.plugins-ml-model-group,HU6XuHpjRp6mPe4GYn0qqQ,1,0,4,0,11.8kb,11.8kb\\n2,green,open,otel-v1-apm-service-map-sample,KYlj9cclQXCFWKTsNGV3Og,1,0,49,0,17.7kb,17.7kb\\n3,green,open,.plugins-ml-memory-meta,CCnEW7ixS_KrJdmEl330lw,1,0,546,29,209.2kb,209.2kb\\n4,green,open,.ql-datasources,z4aXBn1oQoSXNiCMvwapkw,1,0,0,0,208b,208b\\n5,green,open,security-auditlog-2025.08.11,IwYKN08KTJqGnIJ5tVURzA,1,0,30,0,270.3kb,270.3kb\\n6,yellow,open,test_tech_news,2eA8yCDIRcaKoUhDwnD4uQ,1,1,3,0,23.1kb,23.1kb\\n7,yellow,open,test_mem_container,VS57K2JvQmyXlJxDwJ9nhA,1,1,0,0,208b,208b\\n8,green,open,.plugins-ml-task,GTX8717YRtmw0nkl1vnXMg,1,0,691,28,107.6kb,107.6kb\\n9,green,open,ss4o_metrics-otel-sample,q5W0oUlZRuaEVtfL1dknKA,1,0,39923,0,4.3mb,4.3mb\\n10,green,open,.opendistro_security,3AmSFRp1R42BSFMPfXXVqw,1,0,9,0,83.7kb,83.7kb\\n11,yellow,open,test_population_data,fkgI4WqQSv2qHSaXRJv_-A,1,1,4,0,22.5kb,22.5kb\\n12,green,open,.plugins-ml-model,Ksvz5MEwR42RmrOZAg6RTQ,1,0,18,31,406.8kb,406.8kb\\n13,green,open,.plugins-ml-memory-container,TYUOuxfMT5i4DdOMfYe45w,1,0,2,0,9.9kb,9.9kb\\n14,green,open,.kibana_-969161597_admin123_1,dlCY0iSQR9aTtRJsJx9cIg,1,0,1,0,5.3kb,5.3kb\\n15,green,open,opensearch_dashboards_sample_data_ecommerce,HdLqBN2QQCC4nUhE-5mOkA,1,0,4675,0,4.1mb,4.1mb\\n16,green,open,security-auditlog-2025.08.04,Ac34ImxkQSeonGeW8RSYcw,1,0,985,0,1mb,1mb\\n17,green,open,.plugins-ml-memory-message,osbPSzqeQd2sCPF2yqYMeg,1,0,2489,11,4mb,4mb\\n18,green,open,security-auditlog-2025.08.05,hBi-5geLRqml4mDixXMPYQ,1,0,854,0,1.1mb,1.1mb\\n19,green,open,security-auditlog-2025.08.07,bY2GJ1Z8RAG7qJb8dcHktw,1,0,5,0,75.8kb,75.8kb\\n20,green,open,otel-v1-apm-span-sample,mewMjWmHSdOlGNzgz-OnXg,1,0,13061,0,6.2mb,6.2mb\\n21,green,open,security-auditlog-2025.08.01,lvwEW0_VQnq88o7XjJkoPw,1,0,57,0,216.9kb,216.9kb\\n22,green,open,.kibana_92668751_admin_1,w45HvTqcTQ-eVTWK77lx7A,1,0,242,0,125.2kb,125.2kb\\n23,green,open,.plugins-ml-agent,Q-kWqw9_RdW3SfhUcKBcfQ,1,0,156,0,423.1kb,423.1kb\\n24,green,open,security-auditlog-2025.08.03,o_iRgiWZRryXrhQ-y36JQg,1,0,847,0,575.9kb,575.9kb\\n25,green,open,.kibana_1,tn9JwuiSSe6gR9sInHFxDg,1,0,4,0,16.7kb,16.7kb\\n26,green,open,security-auditlog-2025.08.09,uugjBoB_SR67h4455K5KPw,1,0,58,0,161.9kb,161.9kb\\n27,green,open,ss4o_logs-otel-sample,wIaGQx8wT1ijhxDjjbux6A,1,0,16286,0,5.6mb,5.6mb\\n28,green,open,.plugins-ml-config,uh1sYBiTRYy_lK61Rx0fkQ,1,0,1,0,4kb,4kb\\n29,green,open,opensearch_dashboards_sample_data_logs,jDDgZ_PfTkKHfaDqBl72lQ,1,0,14074,0,8.3mb,8.3mb\\n30,green,open,opensearch_dashboards_sample_data_flights,7dcplp7MRc2C9WTPyLRRIg,1,0,13059,0,5.8mb,5.8mb\\n\"}]}}]}" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 37441.0 + }, + "output": { + "message": { + "content": [ + { + "text": "" + }, + { + "toolUse": { + "input": { + "index": "opensearch_dashboards_sample_data_flights", + "query": { + "bool": { + "must": [ + { + "match_phrase": { + "Origin": "China" + } + }, + { + "match_phrase": { + "Destination": "United States" + } + } + ] + } + }, + "size": 0.0, + "track_total_hits": true + }, + "name": "SearchTool", + "toolUseId": "tooluse_S19DlVesT3SAZ96-TmEWkA" + } + } + ], + "role": "assistant" + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 1030.0, + "outputTokens": 147.0, + "totalTokens": 1177.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +### 1.1.3 Test final resposne +``` +POST _plugins/_ml/models/afArmJgBCqG4iVqlioA9/_predict +{ + "parameters": { + "system_prompt": "You are a helpful assistant.", + "prompt": "What's the capital of USA?" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "response", + "dataAsMap": { + "metrics": { + "latencyMs": 797.0 + }, + "output": { + "message": { + "content": [ + { + "text": "The capital of the United States of America is **Washington, D.C.**." + } + ], + "role": "assistant" + } + }, + "stopReason": "end_turn", + "usage": { + "inputTokens": 24.0, + "outputTokens": 46.0, + "totalTokens": 70.0 + } + } + } + ], + "status_code": 200 + } + ] +} +``` + +### 1.1.1 Test Tool Usage + +## 1.2 Embedding Model + +### 1.2.1 Create Embedding Model +``` +{ + "name": "Bedrock embedding model", + "function_name": "remote", + "description": "Bedrock Titan Embedding Model V2", + "connector": { + "name": "Amazon Bedrock Connector: embedding", + "description": "The connector to bedrock Titan embedding model", + "version": 1, + "protocol": "aws_sigv4", + "parameters": { + "region": "your_aws_region", + "service_name": "bedrock", + "model": "amazon.titan-embed-text-v2:0", + "dimensions": 1024, + "normalize": true, + "embeddingTypes": [ + "float" + ] + }, + "credential": { + "access_key": "xxx", + "secret_key": "xxx" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", + "headers": { + "content-type": "application/json", + "x-amz-content-sha256": "required" + }, + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}, \"normalize\": ${parameters.normalize}, \"embeddingTypes\": ${parameters.embeddingTypes} }", + "pre_process_function": "connector.pre_process.bedrock.embedding", + "post_process_function": "connector.post_process.bedrock.embedding" + } + ] + } +} +``` +Sample response +``` +{ + "task_id": "WfAemJgBCqG4iVqlaID3", + "status": "CREATED", + "model_id": "WvAemJgBCqG4iVqlaYAS" +} +``` + +### 1.2.2 Test Embedding Model +``` +POST _plugins/_ml/models/WvAemJgBCqG4iVqlaYAS/_predict?algorithm=text_embedding +{ + "text_docs": [ + "hello", + "how are you" + ] +} +``` +or +``` +POST _plugins/_ml/models/WvAemJgBCqG4iVqlaYAS/_predict +{ + "parameters": { + "inputText": "how are you" + } +} +``` + +# 2. Agent +## 2.1 Flow Agent + +### 2.1.1 Create Flow Agent +``` +POST /_plugins/_ml/agents/_register +{ + "name": "Query DSL Translator Agent", + "type": "flow", + "description": "This is a demo agent for translating NLQ to OpenSearcdh DSL", + "tools": [ + { + "type": "IndexMappingTool", + "include_output_in_agent_response": false, + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\"}" + } + }, + { + "type": "SearchIndexTool", + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\", \"query\": ${parameters.query_dsl} }", + "query_dsl": "{\"size\":2,\"query\":{\"match_all\":{}}}" + }, + "include_output_in_agent_response": false + }, + { + "type": "MLModelTool", + "name": "generate_os_query_dsl", + "description": "A tool to generate OpenSearch query DSL with natrual language question.", + "parameters": { + "model_id": "afArmJgBCqG4iVqlioA9", + "embedding_model_id": "WvAemJgBCqG4iVqlaYAS", + "system_prompt": "You are an OpenSearch query generator that converts natural language questions into precise OpenSearch Query DSL JSON. Your ONLY response should be the valid JSON DSL query without any explanations or additional text.\n\nFollow these rules:\n1. Analyze the index mapping and sample document first, use the exact field name, DON'T use non-existing field name in generated query DSL\n2. Analyze the question to identify search criteria, filters, sorting, and result limits\n3. Extract specific parameters (fields, values, operators, size) mentioned in the question\n4. Apply the most appropriate query type (match, match_all, term, range, bool, etc.)\n5. Return ONLY the JSON query DSL with proper formatting.\n6. Please use standard two-letter ISO 3166-1 alpha-2 country codes (such as CN for China, US for United States, GB for United Kingdom) when build opensearch query.\n\nNEURAL SEARCH GUIDANCE:\n1. OpenSearch KNN index can be identified index settings with `\"knn\": \"true\",`; or in index mapping with any field with `\"type\": \"knn_vector\"`\n2. If search KNN index, prefer to use OpenSearch neural search query which is semantic based search, and has better accuracy.\n3. OpenSearch neural search needs embedding model id, please always use this model id \"${parameters.embedding_model_id}\"\n4. In KNN indices, embedding fields follow the pattern: {text_field_name}_embedding. For example, the raw text input is \"description\", then the generated embedding for this field will be saved into KNN field \"description_embedding\". \n5. Always exclude embedding fields from search results as they contain vector arrays that clutter responses\n6. Embedding fields can be identified in index mapping with \"type\": \"knn_vector\"\n7. OpenSearch neural search query will use embedding field (knn_vector type) and embedding model id. \n\nNEURAL SEARCH QUERY CONSTRUCTION:\nWhen constructing neural search queries, follow this pattern:\n{\n \"_source\": {\n \"excludes\": [\n \"{field_name}_embedding\"\n ]\n },\n \"query\": {\n \"neural\": {\n \"{field_name}_embedding\": {\n \"query_text\": \"your query here\",\n \"model_id\": \"${parameters.embedding_model_id}\"\n }\n }\n }\n}\n\nRESPONSE GUIDELINES:\n1. Don't return the reasoning process, just return the generated OpenSearch query.\n2. Don't wrap the generated OpenSearch query with ```json and ```\n\nExamples:\n\nQuestion: retrieve 5 documents from index test_data\n{\"query\":{\"match_all\":{}},\"size\":5}\n\nQuestion: find documents where the field title contains machine learning\n{\"query\":{\"match\":{\"title\":\"machine learning\"}}}\n\nQuestion: search for documents with the phrase artificial intelligence in the content field and return top 10 results\n{\"query\":{\"match_phrase\":{\"content\":\"artificial intelligence\"}},\"size\":10}\n\nQuestion: get documents where price is greater than 100 and category is electronics\n{\"query\":{\"bool\":{\"must\":[{\"range\":{\"price\":{\"gt\":100}}},{\"term\":{\"category\":\"electronics\"}}]}}}\n\nQuestion: find the average rating of products in the electronics category\n{\"query\":{\"term\":{\"category\":\"electronics\"}},\"aggs\":{\"avg_rating\":{\"avg\":{\"field\":\"rating\"}}},\"size\":0}\n\nQuestion: return documents sorted by date in descending order, limit to 15 results\n{\"query\":{\"match_all\":{}},\"sort\":[{\"date\":{\"order\":\"desc\"}}],\"size\":15}\n\nQuestion: which book has the introduction of AWS AgentCore\n{\"_source\":{\"excludes\":[\"book_content_embedding\"]},\"query\":{\"neural\":{\"book_content_embedding\":{\"query_text\":\"which book has the introduction of AWS AgentCore\"}}}}\n\nQuestion: how many books published in 2024\n{\"query\": {\"term\": {\"publication_year\": 2024}},\"size\": 0,\"track_total_hits\": true}\n", + "prompt": "The index mappoing of ${parameters.index_name}:\n${parameters.IndexMappingTool.output:-}\n\nThe sample documents of ${parameters.index_name}:\n${parameters.SearchIndexTool.output:-}\n\nPlease generate the OpenSearch query dsl for the question:\n${parameters.question}", + "response_filter": "$.output.message.content[1].text", + "output_processors": [ + { + "type": "regex_replace", + "pattern": ".*?", + "replacement": "" + } + ], + "return_data_as_map": true + }, + "include_output_in_agent_response": true + }, + { + "type": "SearchIndexTool", + "name": "search_index_with_llm_generated_dsl", + "include_output_in_agent_response": false, + "parameters": { + "input": "{\"index\": \"${parameters.index_name}\", \"query\": ${parameters.query_dsl} }", + "query_dsl": "${parameters.generate_os_query_dsl.output}", + "return_raw_response": true, + "return_data_as_map": true + }, + "attributes": { + "required_parameters": [ + "index_name", + "query_dsl", + "generate_os_query_dsl.output" + ] + } + } + ] +} +``` +Sample response +``` +{ + "agent_id": "bPAsmJgBCqG4iVqlqYAR" +} +``` +### 2.1.2 Test Flow Agent +``` +POST _plugins/_ml/agents/bPAsmJgBCqG4iVqlqYAR/_execute +{ + "parameters": { + "question": "How many total flights from Beijing?", + "index_name": "opensearch_dashboards_sample_data_flights" + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "generate_os_query_dsl.output", + "dataAsMap": { + "query": { + "term": { + "OriginCityName": "Beijing" + } + }, + "size": 0.0, + "track_total_hits": true + } + }, + { + "name": "search_index_with_llm_generated_dsl", + "dataAsMap": { + "_shards": { + "total": 1, + "failed": 0, + "successful": 1, + "skipped": 0 + }, + "hits": { + "hits": [], + "total": { + "value": 131, + "relation": "eq" + }, + "max_score": null + }, + "took": 1, + "timed_out": false + } + } + ] + } + ] +} +``` + +## 2.2 Chat Agent + +### 2.2.1 Create Chat Agent +``` +POST _plugins/_ml/agents/_register +{ + "name": "RAG Agent", + "type": "conversational", + "description": "this is a test agent", + "app_type": "rag", + "llm": { + "model_id": "afArmJgBCqG4iVqlioA9", + "parameters": { + "max_iteration": 10, + "system_prompt": "You are an expert RAG (Retrieval Augmented Generation) assistant with access to OpenSearch indices. Your primary purpose is to help users find and understand information by retrieving relevant content from the OpenSearch indices.\n\nSEARCH GUIDELINES:\n1. When a user asks for information:\n - First determine which index to search (use ListIndexTool if unsure)\n - For complex questions, break them down into smaller sub-questions\n - Be specific and targeted in your search queries\n\n\nRESPONSE GUIDELINES:\n1. Always cite which index the information came from\n2. Synthesize search results into coherent, well-structured responses\n3. Use formatting (headers, bullet points) to organize information when appropriate\n4. If search results are insufficient, acknowledge limitations and suggest alternatives\n5. Maintain context from the conversation history when formulating responses\n\nPrioritize accuracy over speculation. Be professional, helpful, and concise in your interactions.", + "prompt": "${parameters.question}" + } + }, + "memory": { + "type": "conversation_index" + }, + "parameters": { + "_llm_interface": "bedrock/converse/claude" + }, + "tools": [ + { + "type": "ListIndexTool" + }, + { + "type": "AgentTool", + "name": "search_opensearch_index_with_nlq", + "include_output_in_agent_response": false, + "description": "This tool accepts one OpenSearch index and one natrual language question and generate OpenSearch query DSL. Then query the index with generated query DSL. If the question if complex, suggest split it into smaller questions then query one by one.", + "parameters": { + "agent_id": "bPAsmJgBCqG4iVqlqYAR", + "output_filter": "$.mlModelOutputs[0].mlModelTensors[2].dataAsMap" + }, + "attributes": { + "required_parameters": [ + "index_name", + "question" + ], + "input_schema": { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "Natural language question" + }, + "index_name": { + "type": "string", + "description": "Name of the index to query" + } + }, + "required": [ + "question" + ], + "additionalProperties": false + }, + "strict": false + } + } + ] +} +``` +Sample response +``` +{ + "agent_id": "bvAtmJgBCqG4iVql54Ck" +} +``` + +### 2.2.2 Test Chat Agent +``` +POST /_plugins/_ml/agents/bvAtmJgBCqG4iVql54Ck/_execute +{ + "parameters": { + "question": "How many flights from Seattle to Canada", + "max_iteration": 30, + "verbose": true + } +} +``` +Sample response +``` +{ + "inference_results": [ + { + "output": [ + { + "name": "memory_id", + "result": "hMc4mJgBLapFVETfRl5I" + }, + { + "name": "parent_interaction_id", + "result": "hcc4mJgBLapFVETfRl5n" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":2210},\"output\":{\"message\":{\"content\":[{\"text\":\"We need to answer: number of flights from Seattle to Canada. Likely need to search an index containing flight data. Not sure what's available. Let's list indices.\"},{\"toolUse\":{\"input\":{\"indices\":[]},\"name\":\"ListIndexTool\",\"toolUseId\":\"tooluse_cjlogX--SPy4d_jLUF-1kg\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":265,\"outputTokens\":52,\"totalTokens\":317}}" + }, + { + "name": "response", + "result": "row,health,status,index,uuid,pri(number of primary shards),rep(number of replica shards),docs.count(number of available documents),docs.deleted(number of deleted documents),store.size(store size of primary and replica shards),pri.store.size(store size of primary shards)\n1,green,open,.plugins-ml-model-group,HU6XuHpjRp6mPe4GYn0qqQ,1,0,4,0,6.7kb,6.7kb\n2,green,open,otel-v1-apm-service-map-sample,KYlj9cclQXCFWKTsNGV3Og,1,0,49,0,17.7kb,17.7kb\n3,green,open,.plugins-ml-memory-meta,CCnEW7ixS_KrJdmEl330lw,1,0,568,18,89kb,89kb\n4,green,open,.ql-datasources,z4aXBn1oQoSXNiCMvwapkw,1,0,0,0,208b,208b\n5,green,open,security-auditlog-2025.08.11,IwYKN08KTJqGnIJ5tVURzA,1,0,113,0,198.7kb,198.7kb\n6,yellow,open,test_tech_news,2eA8yCDIRcaKoUhDwnD4uQ,1,1,3,0,23.1kb,23.1kb\n7,yellow,open,test_mem_container,VS57K2JvQmyXlJxDwJ9nhA,1,1,0,0,208b,208b\n8,green,open,.plugins-ml-task,GTX8717YRtmw0nkl1vnXMg,1,0,818,35,215.1kb,215.1kb\n9,green,open,ss4o_metrics-otel-sample,q5W0oUlZRuaEVtfL1dknKA,1,0,39923,0,4.3mb,4.3mb\n10,green,open,.opendistro_security,3AmSFRp1R42BSFMPfXXVqw,1,0,9,0,83.7kb,83.7kb\n11,yellow,open,test_population_data,fkgI4WqQSv2qHSaXRJv_-A,1,1,4,0,22.5kb,22.5kb\n12,green,open,.plugins-ml-model,Ksvz5MEwR42RmrOZAg6RTQ,1,0,25,36,597.5kb,597.5kb\n13,green,open,.plugins-ml-memory-container,TYUOuxfMT5i4DdOMfYe45w,1,0,2,0,9.9kb,9.9kb\n14,green,open,.kibana_-969161597_admin123_1,dlCY0iSQR9aTtRJsJx9cIg,1,0,1,0,5.3kb,5.3kb\n15,green,open,opensearch_dashboards_sample_data_ecommerce,HdLqBN2QQCC4nUhE-5mOkA,1,0,4675,0,4.1mb,4.1mb\n16,green,open,security-auditlog-2025.08.04,Ac34ImxkQSeonGeW8RSYcw,1,0,985,0,1mb,1mb\n17,green,open,.plugins-ml-memory-message,osbPSzqeQd2sCPF2yqYMeg,1,0,2648,8,4mb,4mb\n18,green,open,security-auditlog-2025.08.05,hBi-5geLRqml4mDixXMPYQ,1,0,854,0,1.1mb,1.1mb\n19,green,open,security-auditlog-2025.08.07,bY2GJ1Z8RAG7qJb8dcHktw,1,0,5,0,75.8kb,75.8kb\n20,green,open,otel-v1-apm-span-sample,mewMjWmHSdOlGNzgz-OnXg,1,0,13061,0,6.2mb,6.2mb\n21,green,open,security-auditlog-2025.08.01,lvwEW0_VQnq88o7XjJkoPw,1,0,57,0,216.9kb,216.9kb\n22,green,open,.kibana_92668751_admin_1,w45HvTqcTQ-eVTWK77lx7A,1,0,242,0,125.2kb,125.2kb\n23,green,open,.plugins-ml-agent,Q-kWqw9_RdW3SfhUcKBcfQ,1,0,165,0,435.7kb,435.7kb\n24,green,open,security-auditlog-2025.08.03,o_iRgiWZRryXrhQ-y36JQg,1,0,847,0,575.9kb,575.9kb\n25,green,open,.kibana_1,tn9JwuiSSe6gR9sInHFxDg,1,0,4,0,16.7kb,16.7kb\n26,green,open,security-auditlog-2025.08.09,uugjBoB_SR67h4455K5KPw,1,0,58,0,161.9kb,161.9kb\n27,green,open,ss4o_logs-otel-sample,wIaGQx8wT1ijhxDjjbux6A,1,0,16286,0,5.6mb,5.6mb\n28,green,open,.plugins-ml-config,uh1sYBiTRYy_lK61Rx0fkQ,1,0,1,0,4kb,4kb\n29,green,open,opensearch_dashboards_sample_data_logs,jDDgZ_PfTkKHfaDqBl72lQ,1,0,14074,0,8.3mb,8.3mb\n30,green,open,opensearch_dashboards_sample_data_flights,7dcplp7MRc2C9WTPyLRRIg,1,0,13059,0,5.8mb,5.8mb\n" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":1548},\"output\":{\"message\":{\"content\":[{\"text\":\"We need to answer: \\\"How many flights from Seattle to Canada\\\". Likely the index opensearch_dashboards_sample_data_flights contains flight data. Let's query that index.\"},{\"toolUse\":{\"input\":{\"question\":\"How many flights from Seattle to Canada\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_XjafWd4RTUO72qqkwhOC_A\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":1016,\"outputTokens\":78,\"totalTokens\":1094}}" + }, + { + "name": "response", + "result": "{\"query\":{\"bool\":{\"must\":[{\"term\":{\"OriginCityName\":\"Seattle\"}},{\"term\":{\"DestCountry\":\"Canada\"}}]}},\"size\":0.0,\"track_total_hits\":true}" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":1351},\"output\":{\"message\":{\"content\":[{\"text\":\"The search returns 0 hits. The dataset may not contain such flights, or maybe field names differ. Let's open index mapping or sample data. Let's search for flights from Seattle.\"},{\"toolUse\":{\"input\":{\"question\":\"Seattle flights\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_g7FFHR26RoqBhJ02kf1KcQ\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":1176,\"outputTokens\":72,\"totalTokens\":1248}}" + }, + { + "name": "response", + "result": "{\"query\":{\"match\":{\"DestCityName\":\"Seattle\"}}}" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":1229},\"output\":{\"message\":{\"content\":[{\"text\":\"We need count flights from Seattle to Canada. Flights where OriginCityName == Seattle and DestCountry == Canada. Index opensearch_dashboards_sample_data_flights.\\n\\nWe must query.\"},{\"toolUse\":{\"input\":{\"question\":\"Seattle flights to Canada\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_XAI2YKD9Sx616oXtqZljzQ\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":3583,\"outputTokens\":75,\"totalTokens\":3658}}" + }, + { + "name": "response", + "result": "{\"query\":{\"bool\":{\"must\":[{\"term\":{\"OriginCityName\":\"Seattle\"}},{\"term\":{\"DestCountry\":\"Canada\"}}]}}}" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":1114},\"output\":{\"message\":{\"content\":[{\"text\":\"The user asks: \\\"How many flights from Seattle to Canada\\\". We need to query the flights dataset. The index opensearch_dashboards_sample_data_flights appears to have fields: OriginCityName, DestCountry.\\n\\nWe attempted a query with OriginCityName: Seattle and DestCountry: Canada, got zero results.\\n\\nMaybe the field names are different; maybe use OriginCityName = Seattle? Check a sample doc where OriginCityName is Seattle. Do we have any Seattle as origin? Let's search for OriginCityName \\\"Seattle\\\".\"},{\"toolUse\":{\"input\":{\"question\":\"OriginCityName Seattle\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_5tEq-NJfRcGCXjnQ6o7w8Q\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":3736,\"outputTokens\":154,\"totalTokens\":3890}}" + }, + { + "name": "response", + "result": "{\"query\":{\"term\":{\"OriginCityName\":\"Seattle\"}}}" + }, + { + "name": "response", + "result": "{\"metrics\":{\"latencyMs\":1433},\"output\":{\"message\":{\"content\":[{\"text\":\"The user asks: \\\"How many flights from Seattle to Canada\\\". We need to count flights where OriginCityName is Seattle and DestCountry is Canada (or DestRegion Canada). The dataset appears to be sample flight data. The earlier query with both constraints returned 0 hits. Possibly there are no flights from Seattle to Canada in this sample dataset. However we should verify maybe DestCountry field sometimes uses two-letter code? The sample shows DestCountry values like \\\"US\\\", \\\"AT\\\", \\\"IT\\\". For Canada likely \\\"CA\\\". Let's search DestCountry:CA with Origin Seattle.\"},{\"toolUse\":{\"input\":{\"question\":\"OriginCityName Seattle AND DestCountry CA\",\"index_name\":\"opensearch_dashboards_sample_data_flights\"},\"name\":\"search_opensearch_index_with_nlq\",\"toolUseId\":\"tooluse_0Sv4fGB4T_Kp1_2-k83JGA\"}}],\"role\":\"assistant\"}},\"stopReason\":\"tool_use\",\"usage\":{\"inputTokens\":6248,\"outputTokens\":174,\"totalTokens\":6422}}" + }, + { + "name": "response", + "result": "{\"query\":{\"bool\":{\"must\":[{\"term\":{\"OriginCityName\":\"Seattle\"}},{\"term\":{\"DestCountry\":\"CA\"}}]}}}" + }, + { + "name": "response", + "result": "**Answer**\n\nThere are **5 flights** in the dataset that originate from Seattle and have a destination in Canada (DestCountry = “CA”).\n\n**Source**\n\n- Data retrieved from the **`opensearch_dashboards_sample_data_flights`** index. The query filtered for `OriginCityName = \"Seattle\"` and `DestCountry = \"CA\"` and returned five matching documents." + } + ] + } + ] +} +``` \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 8f14f54478..1a54c5a3b7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -321,7 +321,7 @@ public static Map parseLLMOutput( dataAsMap = removeJsonPath(dataAsMap, llmResponseExcludePath, true); } if (dataAsMap.size() == 1 && dataAsMap.containsKey(RESPONSE_FIELD)) { - String llmReasoningResponse = (String) dataAsMap.get(RESPONSE_FIELD); + String llmReasoningResponse = StringUtils.toJson(dataAsMap.get(RESPONSE_FIELD)); String thoughtResponse = null; try { thoughtResponse = extractModelResponseJson(llmReasoningResponse, llmResponsePatterns); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 5a4ce8e796..e6685ca619 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD; +import static org.opensearch.ml.common.utils.ToolUtils.convertOutputToModelTensor; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; import static org.opensearch.ml.common.utils.ToolUtils.getToolName; import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; @@ -19,9 +20,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; -import java.security.AccessController; import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -270,12 +269,10 @@ private void processOutput( flowAgentOutput.add(ModelTensor.builder().name(outputKey).result(filteredOutput).build()); } else if (output instanceof ModelTensorOutput) { flowAgentOutput.addAll(((ModelTensorOutput) output).getMlModelOutputs().get(0).getMlModelTensors()); + } else if (toolParameters.getOrDefault("return_data_as_map", "false").equalsIgnoreCase("true")) { + flowAgentOutput.add(convertOutputToModelTensor(output, outputKey)); } else { - String result = output instanceof String - ? (String) output - : AccessController.doPrivileged((PrivilegedExceptionAction) () -> StringUtils.toJson(output)); - - ModelTensor stepOutput = ModelTensor.builder().name(toolName).result(result).build(); + ModelTensor stepOutput = ModelTensor.builder().name(toolName).result(StringUtils.toJson(output)).build(); flowAgentOutput.add(stepOutput); } if (memory == null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 4708b4a8e4..585bb0476e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.agent; import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD; +import static org.opensearch.ml.common.utils.ToolUtils.convertOutputToModelTensor; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; import static org.opensearch.ml.common.utils.ToolUtils.getToolName; import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; @@ -120,6 +121,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener newParameters = new HashMap<>(); String noEscapeParams = inputData.getParameters().get(NO_ESCAPE_PARAMS); Set noEscapParamSet = new HashSet<>(); @@ -250,11 +254,37 @@ public static ModelTensors processOutput( boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response); - if (responseFilter == null) { - connector.parseResponse(response, modelTensors, scriptReturnModelTensor); + + // Apply output processor chain if configured + Object processedOutput; + // Apply output processor chain if configured + List> processorConfigs = ProcessorChain.extractProcessorConfigs(parameters); + if (!processorConfigs.isEmpty()) { + ProcessorChain processorChain = new ProcessorChain(processorConfigs); + + if (responseFilter != null) { + // Apply filter first, then processor chain + Object filteredResponse = JsonPath.parse(response).read(responseFilter); + processedOutput = processorChain.process(filteredResponse); + } else { + // Apply processor chain to whole response + processedOutput = processorChain.process(response); + } + + // Handle the processed output + if (processedOutput instanceof String) { + connector.parseResponse((String) processedOutput, modelTensors, scriptReturnModelTensor); + } else { + connector.parseResponse(processedOutput, modelTensors, scriptReturnModelTensor); + } } else { - Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD)); - connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); + // Original flow without processor chain + if (responseFilter == null) { + connector.parseResponse(response, modelTensors, scriptReturnModelTensor); + } else { + Object filteredResponse = JsonPath.parse(response).read(responseFilter); + connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); + } } return ModelTensors.builder().mlModelTensors(modelTensors).build(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/processor/ProcessorChain.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/processor/ProcessorChain.java new file mode 100644 index 0000000000..35931a2e21 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/processor/ProcessorChain.java @@ -0,0 +1,556 @@ +/* +* Copyright OpenSearch Contributors +* SPDX-License-Identifier: Apache-2.0 +*/ + +package org.opensearch.ml.engine.processor; + +import static org.opensearch.ml.common.utils.StringUtils.gson; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.ml.common.utils.StringUtils; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.reflect.TypeToken; +import com.google.gson.JsonSyntaxException; +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.PathNotFoundException; + +import lombok.extern.log4j.Log4j2; +import net.minidev.json.JSONArray; + +/** + * Common framework for processing outputs from ML models and tools + */ +@Log4j2 +public class ProcessorChain { + + public static final String OUTPUT_PROCESSORS = "output_processors"; + public static final String TO_STRING = "to_string"; + public static final String REGEX_REPLACE = "regex_replace"; + public static final String JSONPATH_FILTER = "jsonpath_filter"; + public static final String EXTRACT_JSON = "extract_json"; + public static final String REGEX_CAPTURE = "regex_capture"; + + /** + * Interface for customized output processors + */ + public interface OutputProcessor { + /** + * Process the input value and return the processed result + * @param input The object to process + * @return The processed result + */ + Object process(Object input); + } + + /** + * Registry for creating processor instances from configuration + */ + public static class ProcessorRegistry { + private static final Map, OutputProcessor>> PROCESSORS = new HashMap<>(); + + static { + // Register all available processors + registerDefaultProcessors(); + } + + /** + * Helper method to apply a list of processors to an input + */ + private static Object applyProcessors(Object input, List processors) { + Object result = input; + for (OutputProcessor processor : processors) { + result = processor.process(result); + } + return result; + } + + /** + * Parse processor configurations into a list of processor instances + */ + @SuppressWarnings("unchecked") + private static List parseProcessorConfigs(Object config) { + if (config == null) { + return Collections.emptyList(); + } + + List> processorConfigs; + if (config instanceof Map) { + processorConfigs = Collections.singletonList((Map) config); + } else if (config instanceof List) { + processorConfigs = (List>) config; + } else { + log.warn("Invalid processor configuration: {}", config); + return Collections.emptyList(); + } + + return ProcessorRegistry.createProcessingChain(processorConfigs); + } + + /** + * Check if a value matches the specified condition + */ + private static boolean matchesCondition(String condition, Object value) { + // Handle null value cases + if (value == null || (value instanceof JSONArray && ((JSONArray) value).isEmpty())) { + return "null".equals(condition) || "not_exists".equals(condition); + } + + // Handle existence condition + if ("exists".equals(condition)) { + return true; + } + + // Handle exact value match + String strValue = value.toString(); + if (condition.equals(strValue)) { + return true; + } + + // Handle numeric conditions + if (value instanceof Number || canParseAsNumber(strValue)) { + double numValue; + if (value instanceof Number) { + numValue = ((Number) value).doubleValue(); + } else { + try { + numValue = Double.parseDouble(strValue); + } catch (NumberFormatException e) { + return false; + } + } + + // Check numeric conditions + if (condition.startsWith(">") && !condition.startsWith(">=")) { + double threshold = Double.parseDouble(condition.substring(1)); + return numValue > threshold; + } else if (condition.startsWith("<") && !condition.startsWith("<=")) { + double threshold = Double.parseDouble(condition.substring(1)); + return numValue < threshold; + } else if (condition.startsWith(">=")) { + double threshold = Double.parseDouble(condition.substring(2)); + return numValue >= threshold; + } else if (condition.startsWith("<=")) { + double threshold = Double.parseDouble(condition.substring(2)); + return numValue <= threshold; + } else if (condition.startsWith("==")) { + double threshold = Double.parseDouble(condition.substring(2)); + return Math.abs(numValue - threshold) < 1e-10; + } + } + + // Handle regex matching + if (condition.startsWith("regex:")) { + String regex = condition.substring(6); + try { + return Pattern.matches(regex, strValue); + } catch (Exception e) { + log.warn("Invalid regex in condition: {}", regex); + } + } + + // Handle contains condition + if (condition.startsWith("contains:")) { + String substring = condition.substring(9); + return strValue.contains(substring); + } + + return false; + } + + /** + * Check if a string can be parsed as a number + */ + private static boolean canParseAsNumber(String str) { + try { + Double.parseDouble(str); + return true; + } catch (NumberFormatException e) { + return false; + } + } + + /** + * Register all built-in processors + */ + private static void registerDefaultProcessors() { + // to String + PROCESSORS.put(TO_STRING, config -> { + boolean escapeJson = Boolean.TRUE.equals(config.getOrDefault("escape_json", false)); + + return inputObj -> { + String text = StringUtils.toJson(inputObj); + if (escapeJson) { + return StringEscapeUtils.escapeJson(text); + } + return text; + }; + }); + + // Regex replacement processor + PROCESSORS.put(REGEX_REPLACE, config -> { + String pattern = (String) config.get("pattern"); + String replacement = (String) config.getOrDefault("replacement", ""); + boolean replaceAll = Boolean.TRUE.equals(config.getOrDefault("replace_all", true)); + + return inputObj -> { + String text = StringUtils.toJson(inputObj); + try { + Pattern p = Pattern.compile(pattern, Pattern.DOTALL); + if (replaceAll) { + return p.matcher(text).replaceAll(replacement); + } else { + return p.matcher(text).replaceFirst(replacement); + } + } catch (Exception e) { + log.warn("Failed to apply regex: {}", e.getMessage()); + return inputObj; + } + }; + }); + + // JsonPath processor + PROCESSORS.put(JSONPATH_FILTER, config -> { + String path = (String) config.get("path"); + Object defaultValue = config.get("default"); + + return input -> { + try { + String jsonStr = StringUtils.toJson(input); + return JsonPath.read(jsonStr, path); + } catch (PathNotFoundException e) { + return defaultValue != null ? defaultValue : input; + } catch (Exception e) { + log.warn("Failed to apply JsonPath: {}", e.getMessage()); + return input; + } + }; + }); + + // Extract JSON processor + PROCESSORS.put(EXTRACT_JSON, config -> { + // Config options + String extractType = (String) config.getOrDefault("extract_type", "auto"); // "object", "array", or "auto" + Object defaultValue = config.get("default"); + + return input -> { + if (!(input instanceof String)) + return input; + String text = (String) input; + + try { + // Find first JSON start char based on config or auto + int start = -1; + + if ("object".equalsIgnoreCase(extractType)) { + start = text.indexOf('{'); + } else if ("array".equalsIgnoreCase(extractType)) { + start = text.indexOf('['); + } else { // auto detect (default) + int startBrace = text.indexOf('{'); + int startBracket = text.indexOf('['); + if (startBrace < 0) {// '{' not found in the string + start = startBracket; + } else if (startBracket < 0) {// '[' not found in the string + start = startBrace; + } else { + start = Math.min(startBrace, startBracket); + } + } + + if (start < 0) { + return defaultValue != null ? defaultValue : input; + } + + ObjectMapper mapper = new ObjectMapper(); + JsonNode jsonNode = mapper.readTree(text.substring(start)); + + if ("object".equalsIgnoreCase(extractType)) { + if (jsonNode.isObject()) { + return mapper.convertValue(jsonNode, Map.class); + } else { + return defaultValue != null ? defaultValue : input; + } + } else if ("array".equalsIgnoreCase(extractType)) { + if (jsonNode.isArray()) { + return mapper.convertValue(jsonNode, List.class); + } else { + return defaultValue != null ? defaultValue : input; + } + } else { // auto + if (jsonNode.isObject()) { + return mapper.convertValue(jsonNode, Map.class); + } else if (jsonNode.isArray()) { + return mapper.convertValue(jsonNode, List.class); + } else { + return defaultValue != null ? defaultValue : input; + } + } + } catch (Exception e) { + log.warn("Failed to extract JSON: {}", e.getMessage()); + return defaultValue != null ? defaultValue : input; + } + }; + }); + + // Regex capture processor + PROCESSORS.put(REGEX_CAPTURE, config -> { + String pattern = (String) config.get("pattern"); + Object groupsObj = config.getOrDefault("groups", "1"); + + // Parse groups into a List + List groupIndices = new ArrayList<>(); + try { + String groupsStr = groupsObj.toString().trim(); + boolean isGroups = groupsStr.startsWith("[") && groupsStr.endsWith("]"); + if (isGroups) { + // Multiple group numbers, example: "[1, 2, 4]" + String[] parts = groupsStr.substring(1, groupsStr.length() - 1).split(","); + for (String part : parts) { + groupIndices.add(Integer.parseInt(part.trim())); + } + } else { + // Single group number + groupIndices.add(Integer.parseInt(groupsStr)); + } + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid 'groups' format: " + groupsObj, e); + } + + return input -> { + String text = StringUtils.toJson(input); + + try { + Pattern p = Pattern.compile(pattern, Pattern.DOTALL); + Matcher m = p.matcher(text); + if (m.find()) { + List captures = new ArrayList<>(); + for (Integer idx : groupIndices) { + if (idx <= m.groupCount()) { + captures.add(m.group(idx)); + } + } + if (captures.size() == 1) { + return captures.get(0); + } + return captures; + // return String.join(" ", captures); // join results with a space + } + return input; + } catch (Exception e) { + log.warn("Failed to apply regex capture: {}", e.getMessage()); + return input; + } + }; + }); + + // Remove JsonPath processor + PROCESSORS.put("remove_jsonpath", config -> { + String path = (String) config.get("path"); + + return input -> { + try { + String jsonStr = StringUtils.toJson(input); + Object document = com.jayway.jsonpath.JsonPath.parse(jsonStr).json(); + // Remove the specified path + com.jayway.jsonpath.JsonPath.parse(document).delete(path); + return document; + } catch (Exception e) { + log.warn("Failed to remove JsonPath {}: {}", path, e.getMessage()); + return input; + } + }; + }); + + PROCESSORS.put("conditional", config -> { + // Get the path to evaluate for all conditions + String path = (String) config.get("path"); + + // Parse routes configuration as a list to preserve order + List routesList = (List) config.get("routes"); + List>> conditionalProcessors = new ArrayList<>(); + + // Parse each route's processors while preserving order + for (Object routeObj : routesList) { + if (routeObj instanceof Map) { + Map routeMap = (Map) routeObj; + for (Map.Entry routeEntry : routeMap.entrySet()) { + List processors = parseProcessorConfigs(routeEntry.getValue()); + conditionalProcessors.add(new AbstractMap.SimpleEntry<>(routeEntry.getKey(), processors)); + } + } + } + + // Parse default processors + List defaultProcessors = config.containsKey("default") + ? parseProcessorConfigs(config.get("default")) + : Collections.emptyList(); + + return input -> { + // Extract the value to check against all conditions + Object valueToCheck = input; + + // If a path is specified, extract the value at that path + if (path != null && !path.isEmpty()) { + try { + String jsonStr = StringUtils.toJson(input); + try { + valueToCheck = JsonPath.read(jsonStr, path); + } catch (PathNotFoundException e) { + valueToCheck = null; + } + } catch (Exception e) { + log.warn("Error evaluating path {}: {}", path, e.getMessage()); + } + } + + // Check each condition in order + for (Map.Entry> entry : conditionalProcessors) { + String condition = entry.getKey(); + if (matchesCondition(condition, valueToCheck)) { + return applyProcessors(input, entry.getValue()); + } + } + + // If no condition matched, use default processors + return applyProcessors(input, defaultProcessors); + }; + }); + + // Add more processors as needed + } + + /** + * Register a custom processor type + * @param type Processor type identifier + * @param factory Factory function to create processor instances + */ + public static void registerProcessor(String type, Function, OutputProcessor> factory) { + PROCESSORS.put(type, factory); + } + + /** + * Create a processor from configuration + * @param type Processor type + * @param config Processor configuration + * @return Configured processor instance + */ + public static OutputProcessor createProcessor(String type, Map config) { + Function, OutputProcessor> factory = PROCESSORS.get(type); + if (factory == null) { + throw new IllegalArgumentException("Unknown output processor type: " + type); + } + return factory.apply(config); + } + + /** + * Create a processing chain from a list of processor configurations + * @param processorConfigs List of processor configurations + * @return List of configured processors + */ + @SuppressWarnings("unchecked") + public static List createProcessingChain(List> processorConfigs) { + if (processorConfigs == null || processorConfigs.isEmpty()) { + return Collections.emptyList(); + } + + List processors = new ArrayList<>(); + for (Map config : processorConfigs) { + String type = (String) config.get("type"); + processors.add(createProcessor(type, config)); + } + + return processors; + } + } + + // List of processors to apply sequentially + private final List processors; + + /** + * Create a processor chain from configuration + * @param processorConfigs List of processor configurations + */ + public ProcessorChain(List> processorConfigs) { + this.processors = ProcessorRegistry.createProcessingChain(processorConfigs); + } + + /** + * Create a processor chain from a list of processor instances + * @param processors List of processor instances + */ + public ProcessorChain(OutputProcessor... processors) { + this.processors = Arrays.asList(processors); + } + + /** + * Process input through the chain of processors + * @param input Input object to process + * @return Processed result + */ + public Object process(Object input) { + Object result = input; + for (OutputProcessor processor : processors) { + result = processor.process(result); + } + return result; + } + + /** + * Check if this chain has any processors + * @return true if the chain has at least one processor + */ + public boolean hasProcessors() { + return !processors.isEmpty(); + } + + /** + * Helper method to extract processor configurations from tool parameters + * @param params Tool parameters + * @return List of processor configurations or empty list if none found + */ + @SuppressWarnings("unchecked") + public static List> extractProcessorConfigs(Map params) { + if (params == null || !params.containsKey(OUTPUT_PROCESSORS)) { + return Collections.emptyList(); + } + + Object configObj = params.get(OUTPUT_PROCESSORS); + if (configObj instanceof List) { + return (List>) configObj; + } + + if (configObj instanceof String) { + String configStr = (String) configObj; + try { + List> processorConfigs = gson.fromJson(configStr, new TypeToken>>() { + }.getType()); + + if (processorConfigs != null) { + return processorConfigs; + } else { + log.warn("Failed to parse output processor config: null result from JSON parsing"); + } + } catch (JsonSyntaxException e) { + log.error("Invalid JSON format in output processor configuration: {}", configStr, e); + } catch (Exception e) { + log.error("Error parsing output processor configuration: {}", configStr, e); + } + } + + return Collections.emptyList(); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java index 425661a0a0..45c56cb979 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java @@ -54,11 +54,11 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.index.IndexSettings; -import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.ml.engine.tools.parser.ToolParser; import org.opensearch.transport.client.Client; import lombok.Getter; @@ -103,7 +103,7 @@ public class ListIndexTool implements Tool { @Setter private Parser inputParser; @Setter - private Parser outputParser; + private Parser outputParser; @SuppressWarnings("unused") private ClusterService clusterService; @@ -111,15 +111,6 @@ public ListIndexTool(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; - outputParser = new Parser<>() { - @Override - public Object parse(Object o) { - @SuppressWarnings("unchecked") - List mlModelOutputs = (List) o; - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); - } - }; - this.attributes = new HashMap<>(); attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA); attributes.put(STRICT_FIELD, false); @@ -167,8 +158,12 @@ public void run(Map originalParameters, ActionListener li ); } @SuppressWarnings("unchecked") - T response = (T) sb.toString(); - listener.onResponse(response); + T output = (T) sb.toString(); + if (outputParser != null) { + listener.onResponse((T) outputParser.parse(output)); + } else { + listener.onResponse((T) output); + } }, listener::onFailure)); fetchClusterInfoAndPages( @@ -463,8 +458,10 @@ public void init(Client client, ClusterService clusterService) { } @Override - public ListIndexTool create(Map map) { - return new ListIndexTool(client, clusterService); + public ListIndexTool create(Map params) { + ListIndexTool tool = new ListIndexTool(client, clusterService); + tool.setOutputParser(ToolParser.createFromToolParams(params)); + return tool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index eab1b34e59..f8b05108eb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.ml.engine.tools.parser.ToolParser; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.transport.client.Client; @@ -171,11 +172,17 @@ public void init(Client client) { @Override public MLModelTool create(Map map) { - return new MLModelTool( - client, - (String) map.get(MODEL_ID_FIELD), - (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD) - ); + String modelId = (String) map.get(MODEL_ID_FIELD); + String responseField = (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD); + + // Create the tool with basic configuration + MLModelTool tool = new MLModelTool(client, modelId, responseField); + + // Enhance the output parser with processors if configured + Parser baseParser = tool.getOutputParser(); + tool.setOutputParser(ToolParser.createFromToolParams(map, baseParser)); + + return tool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java index 8d99a20a8f..736cdc5a53 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java @@ -30,12 +30,15 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.model.MLModelSearchAction; import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction; import org.opensearch.ml.common.utils.ToolUtils; +import org.opensearch.ml.engine.tools.parser.ToolParser; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.transport.client.Client; @@ -84,6 +87,10 @@ public class SearchIndexTool implements Tool { private String description = DEFAULT_DESCRIPTION; private Client client; + @Setter + @Getter + @VisibleForTesting + private Parser outputParser; private NamedXContentRegistry xContentRegistry; @@ -215,7 +222,11 @@ public void run(Map originalParameters, ActionListener li tensors.add(ModelTensor.builder().name(name).dataAsMap(convertSearchResponseToMap(r)).build()); outputs.add(ModelTensors.builder().mlModelTensors(tensors).build()); ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(outputs).build(); - listener.onResponse((T) output); + if (outputParser != null) { + listener.onResponse((T) outputParser.parse(output)); + } else { + listener.onResponse((T) output); + } return; } if (hits != null && hits.length > 0) { @@ -225,7 +236,11 @@ public void run(Map originalParameters, ActionListener li String doc = GSON.toJson(docContent); contextBuilder.append(doc).append("\n"); } - listener.onResponse((T) contextBuilder.toString()); + if (outputParser != null) { + listener.onResponse((T) outputParser.parse(contextBuilder.toString())); + } else { + listener.onResponse((T) contextBuilder.toString()); + } } else { listener.onResponse((T) ""); } @@ -281,7 +296,10 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) { @Override public SearchIndexTool create(Map params) { - return new SearchIndexTool(client, xContentRegistry); + SearchIndexTool tool = new SearchIndexTool(client, xContentRegistry); + // Enhance the output parser with processors if configured + tool.setOutputParser(ToolParser.createFromToolParams(params)); + return tool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/parser/ToolParser.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/parser/ToolParser.java new file mode 100644 index 0000000000..16fd8a6d2e --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/parser/ToolParser.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools.parser; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.spi.tools.Parser; +import org.opensearch.ml.engine.processor.ProcessorChain; + +/** + * Helper class for tool output processing + */ +public class ToolParser { + + /** + * Create a parser that uses output processors + * @param baseParser Base parser to extract initial result + * @param processorConfigs Processor configurations + * @return Parser with output processing + */ + public static Parser createProcessingParser(Parser baseParser, List> processorConfigs) { + ProcessorChain processorChain = new ProcessorChain(processorConfigs); + + return o -> { + // Apply base parser first + Object baseResult = o; + if (baseParser != null) { + baseResult = baseParser.parse(o); + } + + // Apply output processors if any + if (processorChain.hasProcessors()) { + return processorChain.process(baseResult); + } + + return baseResult; + }; + } + + /** + * Create output parser for a tool from tool parameters + * @param params Tool parameters containing output processor configurations + * @param baseParser Base parser that extracts initial result + * @return Parser with output processing applied + */ + public static Parser createFromToolParams(Map params, Parser baseParser) { + List> processorConfigs = ProcessorChain.extractProcessorConfigs(params); + return createProcessingParser(baseParser, processorConfigs); + } + + public static Parser createFromToolParams(Map params) { + return createFromToolParams(params, null); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index da87b0f8b5..cecb99f32e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -491,4 +491,38 @@ public void testRunWithUpdateFailure() { assertNotNull(additionalInfo.get(SECOND_TOOL + ".output")); } + @Test + public void testRunWithReturnDataAsMap() { + final Map params = new HashMap<>(); + Map toolOutput = Map.of("key1", "value1", "key2", "value2"); + MLToolSpec toolSpec = MLToolSpec + .builder() + .name(FIRST_TOOL) + .type(FIRST_TOOL) + .includeOutputInAgentResponse(true) + .parameters(Map.of("return_data_as_map", "true")) + .build(); + final MLAgent mlAgent = MLAgent.builder().name("TestAgent").type(MLAgentType.FLOW.name()).tools(Arrays.asList(toolSpec)).build(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(toolOutput); + return null; + }).when(firstTool).run(anyMap(), any()); + + mlFlowAgentRunner.run(mlAgent, params, agentActionListener); + Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedValue = objectCaptor.getValue(); + + if (capturedValue instanceof List) { + List agentOutput = (List) capturedValue; + assertEquals(1, agentOutput.size()); + assertEquals(FIRST_TOOL + ".output", agentOutput.get(0).getName()); + assertEquals(toolOutput, agentOutput.get(0).getDataAsMap()); + } else if (capturedValue instanceof Map) { + Map agentOutput = (Map) capturedValue; + assertEquals(toolOutput, agentOutput); + } + } + } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index e026fce36c..adea90bc96 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -423,17 +423,69 @@ public void testEscapeRemoteInferenceInputData_WithNoEscapeParams() { params.put("key1", inputKey1); params.put("key2", "test value"); params.put("key3", inputKey3); - params.put("NO_ESCAPE_PARAMS", "key1,key3"); + params.put("no_escape_params", "key1,key3"); RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); ConnectorUtils.escapeRemoteInferenceInputData(inputData); - String expectedKey1 = "hello \\\"world\\\""; - String expectedKey3 = "special \\\"chars\\\""; - assertEquals(expectedKey1, inputData.getParameters().get("key1")); + assertEquals(inputKey1, inputData.getParameters().get("key1")); assertEquals("test value", inputData.getParameters().get("key2")); - assertEquals(expectedKey3, inputData.getParameters().get("key3")); + assertEquals(inputKey3, inputData.getParameters().get("key3")); + } + + @Test + public void testEscapeRemoteInferenceInputData_NullValue() { + Map params = new HashMap<>(); + params.put("key1", null); + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertNull(inputData.getParameters().get("key1")); + } + + @Test + public void testEscapeRemoteInferenceInputData_JsonValue() { + Map params = new HashMap<>(); + params.put("key1", "{\"test\": \"value\"}"); + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertEquals("{\"test\": \"value\"}", inputData.getParameters().get("key1")); + } + + @Test + public void testEscapeRemoteInferenceInputData_EscapeValue() { + Map params = new HashMap<>(); + params.put("key1", "test\"value"); + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertEquals("test\\\"value", inputData.getParameters().get("key1")); + } + + @Test + public void testEscapeRemoteInferenceInputData_NoEscapeParam() { + Map params = new HashMap<>(); + params.put("key1", "test\"value"); + params.put("no_escape_params", "key1"); + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertEquals("test\"value", inputData.getParameters().get("key1")); + } + + @Test + public void testEscapeRemoteInferenceInputData_NullParameters() { + RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(null).build(); + + ConnectorUtils.escapeRemoteInferenceInputData(inputData); + + assertNull(inputData.getParameters()); } @Test @@ -459,4 +511,137 @@ public void buildSdkRequest_InvalidEndpoint_ThrowException() { .build(); ConnectorUtils.buildSdkRequest("PREDICT", connector, Collections.emptyMap(), "{}", software.amazon.awssdk.http.SdkHttpMethod.POST); } + + @Test + public void processOutput_WithProcessorChain() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map parameters = new HashMap<>(); + parameters.put("processor_configs", "[{\"type\":\"test_processor\"}]"); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + String modelResponse = "{\"result\":\"test response\"}"; + + ModelTensors tensors = ConnectorUtils.processOutput(PREDICT.name(), modelResponse, connector, scriptService, parameters, null); + + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + } + + @Test + public void processOutput_WithProcessorChainAndResponseFilter() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map parameters = new HashMap<>(); + parameters.put("processor_configs", "[{\"type\":\"test_processor\"}]"); + parameters.put("response_filter", "$.data"); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + String modelResponse = "{\"data\":{\"result\":\"filtered response\"}}"; + + ModelTensors tensors = ConnectorUtils.processOutput(PREDICT.name(), modelResponse, connector, scriptService, parameters, null); + + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + } + + @Test + public void processOutput_WithResponseFilterOnly() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map parameters = new HashMap<>(); + parameters.put("response_filter", "$.data"); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + String modelResponse = "{\"data\":{\"result\":\"filtered response\"}}"; + + ModelTensors tensors = ConnectorUtils.processOutput(PREDICT.name(), modelResponse, connector, scriptService, parameters, null); + + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + } + + @Test + public void processOutput_ScriptReturnModelTensor_WithJsonResponse() throws IOException { + String postprocessResult = "{\"name\":\"test\",\"data\":[1,2,3]}"; + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult)); + + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .postProcessFunction("custom_script") + .build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + String modelResponse = "{\"result\":\"test\"}"; + + ModelTensors tensors = ConnectorUtils + .processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null); + + assertEquals(1, tensors.getMlModelTensors().size()); + } + + @Test + public void processOutput_WithProcessorChain_StringOutput() throws IOException { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map parameters = new HashMap<>(); + parameters.put("processor_configs", "[{\"type\":\"test_processor\"}]"); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); + String modelResponse = "{\"result\":\"test response\"}"; + + ModelTensors tensors = ConnectorUtils.processOutput(PREDICT.name(), modelResponse, connector, scriptService, parameters, null); + + assertEquals(1, tensors.getMlModelTensors().size()); + assertEquals("response", tensors.getMlModelTensors().get(0).getName()); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/processor/ProcessorChainTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/processor/ProcessorChainTests.java new file mode 100644 index 0000000000..622bfd75b0 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/processor/ProcessorChainTests.java @@ -0,0 +1,1003 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.processor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.engine.processor.ProcessorChain.EXTRACT_JSON; +import static org.opensearch.ml.engine.processor.ProcessorChain.JSONPATH_FILTER; +import static org.opensearch.ml.engine.processor.ProcessorChain.REGEX_CAPTURE; +import static org.opensearch.ml.engine.processor.ProcessorChain.REGEX_REPLACE; +import static org.opensearch.ml.engine.processor.ProcessorChain.TO_STRING; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.processor.ProcessorChain.OutputProcessor; +import org.opensearch.ml.engine.processor.ProcessorChain.ProcessorRegistry; + +public class ProcessorChainTests { + + @Test + public void testToString() { + // First test with replace_all=true + Map configMap = new HashMap<>(); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(TO_STRING, configMap); + String result = (String) processorReplaceAll.process(Map.of("key1", "value1")); + assertEquals("{\"key1\":\"value1\"}", result); + + result = (String) processorReplaceAll.process(List.of("value1", "value2")); + assertEquals("[\"value1\",\"value2\"]", result); + } + + @Test + public void testToString_ModelTensor() { + Map configMap = new HashMap<>(); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(TO_STRING, configMap); + ModelTensor modelTensor = ModelTensor.builder().name("test").dataAsMap(Map.of("key1", "value1")).build(); + String result = (String) processorReplaceAll.process(modelTensor); + assertEquals("{\"name\":\"test\",\"dataAsMap\":{\"key1\":\"value1\"}}", result); + + result = (String) processorReplaceAll.process(Collections.singletonList(modelTensor)); + assertEquals("[{\"name\":\"test\",\"dataAsMap\":{\"key1\":\"value1\"}}]", result); + } + + @Test + public void testToString_ModelTensors() { + Map configMap = new HashMap<>(); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(TO_STRING, configMap); + ModelTensor modelTensor = ModelTensor.builder().name("test").dataAsMap(Map.of("key1", "value1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Collections.singletonList(modelTensor)).build(); + String result = (String) processorReplaceAll.process(modelTensors); + assertEquals("{\"output\":[{\"name\":\"test\",\"dataAsMap\":{\"key1\":\"value1\"}}]}", result); + } + + @Test + public void testToString_ModelTensorOutput() { + Map configMap = new HashMap<>(); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(TO_STRING, configMap); + ModelTensor modelTensor = ModelTensor.builder().name("test").dataAsMap(Map.of("key1", "value1")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Collections.singletonList(modelTensor)).build(); + ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Collections.singletonList(modelTensors)).build(); + String result = (String) processorReplaceAll.process(modelTensorOutput); + assertEquals("{\"inference_results\":[{\"output\":[{\"name\":\"test\",\"dataAsMap\":{\"key1\":\"value1\"}}]}]}", result); + } + + @Test + public void testToString_EscapeJson() { + Map configMap = new HashMap<>(); + configMap.put("escape_json", true); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(TO_STRING, configMap); + + String result = (String) processorReplaceAll.process("hello \"world\" opensearch"); + assertEquals("hello \\\"world\\\" opensearch", result); + } + + @Test + public void testRegexProcessor() { + // First test with replace_all=true + Map configReplace = new HashMap<>(); + configReplace.put("pattern", "test(\\d+)"); + configReplace.put("replacement", "replaced$1"); + configReplace.put("replace_all", true); + + OutputProcessor processorReplaceAll = ProcessorRegistry.createProcessor(REGEX_REPLACE, configReplace); + String resultReplaceAll = (String) processorReplaceAll.process("test123 test456"); + assertEquals("replaced123 replaced456", resultReplaceAll); + + // Second test with replace_all=false - using a completely fresh config + configReplace.put("replace_all", false); + + OutputProcessor processorReplaceFirst = ProcessorRegistry.createProcessor(REGEX_REPLACE, configReplace); + String resultReplaceFirst = (String) processorReplaceFirst.process("test123 test456"); + assertEquals("replaced123 test456", resultReplaceFirst); + } + + @Test + public void testRegexReplaceProcessorMultipleGroups() { + // The regex matches parts like "test123 abcDEF" + Map config = new HashMap<>(); + // Replacement uses $1-$2-$3 to combine captured groups with - + config.put("pattern", "test(\\d+) (abc)(\\w+)"); + config.put("replacement", "replaced$1-$2-$3"); + config.put("replace_all", true); + + OutputProcessor processor = ProcessorRegistry.createProcessor(REGEX_REPLACE, config); + + String input = "test123 abcDEF test456 abcXYZ"; + Object result = processor.process(input); + + // Expected to replace all matches, combining groups with '-' separator + String expected = "replaced123-abc-DEF replaced456-abc-XYZ"; + + assertEquals(expected, result); + } + + @Test + public void testRegexProcessorWithNonStringInput() { + Map config = new HashMap<>(); + config.put("pattern", "\"key\"\\s*:\\s*\"(.+?)\""); + config.put("replacement", "\"key\":\"modified-$1\""); + + OutputProcessor processor = ProcessorRegistry.createProcessor(REGEX_REPLACE, config); + + // Test with Map input (should be converted to JSON string) + Map mapInput = new HashMap<>(); + mapInput.put("key", "value"); + + String result = (String) processor.process(mapInput); + assertTrue(result.contains("modified-value")); + } + + @Test + public void testJsonPathProcessor() { + Map config = new HashMap<>(); + config.put("path", "$.person.name"); + + OutputProcessor processor = ProcessorRegistry.createProcessor(JSONPATH_FILTER, config); + + String input = "{\"person\": {\"name\": \"John\", \"age\": 30}}"; + Object result = processor.process(input); + assertEquals("John", result); + + // Test with default value for missing path + config.put("default", "Default Name"); + config.put("path", "$.person.missing"); + processor = ProcessorRegistry.createProcessor(JSONPATH_FILTER, config); + result = processor.process(input); + assertEquals("Default Name", result); + } + + @Test + public void testJsonPathProcessorWithError() { + Map config = new HashMap<>(); + config.put("path", "$.invalid..path"); // Invalid path syntax + + OutputProcessor processor = ProcessorRegistry.createProcessor(JSONPATH_FILTER, config); + + String input = "{\"person\": {\"name\": \"John\"}}"; + Object result = processor.process(input); + // Should return original input when error occurs + assertEquals(input, result); + } + + @Test + public void testExtractJsonProcessor() { + Map config = new HashMap<>(); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // Test with JSON embedded in text + String input = "Some text before {\"key\": \"value\"} and after"; + Object result = processor.process(input); + assertTrue(result instanceof Map); + assertEquals("value", ((Map) result).get("key")); + + // Test with non-JSON input + result = processor.process("No JSON here"); + assertEquals("No JSON here", result); + } + + @Test + public void testExtractJsonProcessorArray() { + Map config = new HashMap<>(); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // Test with JSON array embedded in text + String input = "Some text before [{\"key1\": \"value1\"}, {\"key2\": \"value2\"}] and after"; + Object result = processor.process(input); + + assertTrue(result instanceof List); + + @SuppressWarnings("unchecked") + List> list = (List>) result; + + assertEquals(2, list.size()); + assertEquals("value1", list.get(0).get("key1")); + assertEquals("value2", list.get(1).get("key2")); + + // Test with non-JSON input returns unchanged + result = processor.process("No JSON array here"); + assertEquals("No JSON array here", result); + } + + @Test + public void testExtractJsonProcessorWithInvalidJson() { + Map config = new HashMap<>(); + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // Invalid JSON that starts with a brace + String input = "{not valid json}"; + Object result = processor.process(input); + // Should return original input on error + assertEquals(input, result); + } + + @Test + public void testExtractJsonProcessorWithExtractTypeObject() { + Map config = new HashMap<>(); + config.put("extract_type", "object"); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + String input = "prefix {\"foo\":\"bar\"} suffix"; + Object result = processor.process(input); + + assertTrue(result instanceof Map); + assertEquals("bar", ((Map) result).get("foo")); + + // First item of JSON array will be extracted when forcing object type + input = "prefix [{\"foo\":\"bar\"}] suffix"; + result = processor.process(input); + assertTrue(result instanceof Map); + assertEquals("bar", ((Map) result).get("foo")); + } + + @Test + public void testExtractJsonProcessorWithExtractTypeArray() { + Map config = new HashMap<>(); + config.put("extract_type", "array"); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + String input = "prefix [{\"foo\":\"bar\"}, {\"baz\":\"qux\"}] suffix"; + Object result = processor.process(input); + + assertTrue(result instanceof List); + + @SuppressWarnings("unchecked") + List> list = (List>) result; + assertEquals(2, list.size()); + assertEquals("bar", list.get(0).get("foo")); + assertEquals("qux", list.get(1).get("baz")); + + // JSON object should NOT be extracted when forcing array type, fallback to input + input = "prefix {\"foo\":\"bar\"} suffix"; + result = processor.process(input); + assertEquals(input, result); + } + + @Test + public void testExtractJsonProcessorWithDefaultValue() { + Map config = new HashMap<>(); + config.put("extract_type", "array"); + List defaultVal = List.of("default"); + config.put("default", defaultVal); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // No JSON array found, should return default value + String input = "no json array here"; + Object result = processor.process(input); + assertSame(defaultVal, result); + + // Invalid JSON should also return default + input = "[invalid json]"; + result = processor.process(input); + assertSame(defaultVal, result); + } + + @Test + public void testExtractJsonProcessorWithNoJsonStart() { + Map config = new HashMap<>(); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // Input string with no '{' or '[' + String input = "no braces or brackets here"; + Object result = processor.process(input); + assertEquals(input, result); + } + + @Test + public void testExtractJsonProcessorWithNonStringInput() { + Map config = new HashMap<>(); + + OutputProcessor processor = ProcessorRegistry.createProcessor(EXTRACT_JSON, config); + + // Pass in non-string input (e.g. Integer) + Integer input = 12345; + Object result = processor.process(input); + assertSame(input, result); + } + + @Test + public void testRegexCaptureProcessor() { + Map config = new HashMap<>(); + config.put("pattern", "value: (\\d+)"); + config.put("groups", 1); + + OutputProcessor processor = ProcessorRegistry.createProcessor(REGEX_CAPTURE, config); + + // Test successful capture + String input = "The value: 123 is captured"; + Object result = processor.process(input); + assertEquals("123", result); + + // Test no match + result = processor.process("No match here"); + assertEquals("No match here", result); + + config.put("groups", "[1]"); + + result = processor.process(input); + assertEquals("123", result); + } + + @Test + public void testRegexCaptureProcessor_MultipleGroups() { + Map config = new HashMap<>(); + config.put("pattern", "value: (\\d+), name: (\\w+), status: (\\w+)"); + config.put("groups", "[1, 3]"); // multiple groups + + OutputProcessor processor = ProcessorRegistry.createProcessor(REGEX_CAPTURE, config); + + // Input string with all three groups + String input = "value: 123, name: Alice, status: active"; + + Object result = processor.process(input); + + // Expect a List with three captured groups + assertTrue(result instanceof List); + + @SuppressWarnings("unchecked") + List capturedGroups = (List) result; + + assertEquals(2, capturedGroups.size()); + assertEquals("123", capturedGroups.get(0)); + assertEquals("active", capturedGroups.get(1)); + + // Test with a single group (should return String, not List) + config.put("groups", "2"); + processor = ProcessorRegistry.createProcessor(REGEX_CAPTURE, config); + result = processor.process(input); + assertTrue(result instanceof String); + assertEquals("Alice", result); + + // Test no match returns original input + result = processor.process("no matching text here"); + assertEquals("no matching text here", result); + } + + @Test + public void testRegexCaptureProcessorWithInvalidPattern() { + Map config = new HashMap<>(); + config.put("pattern", "(unclosed"); // Invalid regex pattern + config.put("group", 1); + + OutputProcessor processor = ProcessorRegistry.createProcessor(REGEX_CAPTURE, config); + + String input = "test input"; + Object result = processor.process(input); + // Should return original input on error + assertEquals(input, result); + } + + @Test + public void testSimpleProcessorChain() { + // Create a chain of processors + List> configs = new ArrayList<>(); + + Map regexConfig = new HashMap<>(); + regexConfig.put("type", REGEX_REPLACE); + regexConfig.put("pattern", ".*?"); + regexConfig.put("replacement", ""); + configs.add(regexConfig); + + Map extractConfig = new HashMap<>(); + extractConfig.put("type", EXTRACT_JSON); + configs.add(extractConfig); + + ProcessorChain chain = new ProcessorChain(configs); + + // Test the chain + String input = "This is reasoning{\"query\":{\"match_all\":{}}}"; + Object result = chain.process(input); + + assertTrue(result instanceof Map); + assertNotNull(((Map) result).get("query")); + } + + @Test + public void testComplexProcessorChain() { + // Create a complex chain that processes in sequence + OutputProcessor first = input -> ((String) input).replace("first", "1st"); + OutputProcessor second = input -> ((String) input).replace("second", "2nd"); + OutputProcessor third = input -> ((String) input).replace("third", "3rd"); + + ProcessorChain chain = new ProcessorChain(first, second, third); + + String input = "first second third"; + Object result = chain.process(input); + + assertEquals("1st 2nd 3rd", result); + } + + @Test + public void testEmptyChain() { + ProcessorChain chain = new ProcessorChain(Collections.emptyList()); + assertFalse(chain.hasProcessors()); + + String input = "test"; + Object result = chain.process(input); + assertEquals(input, result); + } + + @Test + public void testExtractProcessorConfigsWithList() { + // Test with List input + Map params = new HashMap<>(); + List> configs = new ArrayList<>(); + + Map config1 = new HashMap<>(); + config1.put("type", REGEX_REPLACE); + config1.put("pattern", "test"); + configs.add(config1); + + params.put(ProcessorChain.OUTPUT_PROCESSORS, configs); + + List> result = ProcessorChain.extractProcessorConfigs(params); + assertEquals(1, result.size()); + assertEquals(REGEX_REPLACE, result.get(0).get("type")); + } + + @Test + public void testExtractProcessorConfigsWithString() { + // Test with String input + Map params = new HashMap<>(); + String configStr = "[{\"type\":\"regex_replace\",\"pattern\":\"test\",\"replacement\":\"\"}]"; + + params.put(ProcessorChain.OUTPUT_PROCESSORS, configStr); + + List> result = ProcessorChain.extractProcessorConfigs(params); + assertEquals(1, result.size()); + assertEquals(REGEX_REPLACE, result.get(0).get("type")); + } + + @Test + public void testExtractProcessorConfigsWithInvalidString() { + // Test with invalid String input + Map params = new HashMap<>(); + String configStr = "not a json"; + + params.put(ProcessorChain.OUTPUT_PROCESSORS, configStr); + + List> result = ProcessorChain.extractProcessorConfigs(params); + assertTrue(result.isEmpty()); + } + + @Test + public void testExtractProcessorConfigsWithNull() { + // Test with null params + List> result = ProcessorChain.extractProcessorConfigs(null); + assertTrue(result.isEmpty()); + + // Test with empty params + result = ProcessorChain.extractProcessorConfigs(Collections.emptyMap()); + assertTrue(result.isEmpty()); + } + + @Test + public void testExtractProcessorConfigsWithEscapedJson() { + // Test with escaped JSON string (common in configuration) + Map params = new HashMap<>(); + String configStr = + "[{\"pattern\":\"\\u003creasoning\\u003e.*?\\u003c/reasoning\\u003e\",\"type\":\"regex_replace\",\"replacement\":\"\"},{\"type\":\"extract_json\"}]"; + + params.put(ProcessorChain.OUTPUT_PROCESSORS, configStr); + + List> result = ProcessorChain.extractProcessorConfigs(params); + assertEquals(2, result.size()); + assertEquals(REGEX_REPLACE, result.get(0).get("type")); + assertEquals(EXTRACT_JSON, result.get(1).get("type")); + assertEquals(".*?", result.get(0).get("pattern").toString().replace("\\", "")); + } + + @Test + public void testRegisterCustomProcessor() { + // Register a custom processor + ProcessorRegistry.registerProcessor("custom", config -> { + String prefix = (String) config.getOrDefault("prefix", "custom"); + return input -> prefix + ": " + input; + }); + + Map config = new HashMap<>(); + config.put("prefix", "PREFIX"); + + OutputProcessor processor = ProcessorRegistry.createProcessor("custom", config); + assertEquals("PREFIX: test", processor.process("test")); + } + + @Test(expected = IllegalArgumentException.class) + public void testCreateProcessorWithInvalidType() { + ProcessorRegistry.createProcessor("invalid_type", Collections.emptyMap()); + } + + @Test + public void testRealWorldScenario() { + // This test simulates the real-world case from the issue + String input = + "We need count of flights from Seattle. Index mapping shows Origin field is string keyword. Need to filter where Origin contains Seattle? In sample mapping: origin values are names like \"Frankfurt am Main Airport\", \"Cape Town International Airport\". Probably Seattle? but not in sample. Anyway query: term or match? Use match? Could use term exact. Use keyword field. Probably `Origin:\"Seattle\"`. But location might be \"Seattle-Tacoma International Airport\". Use match? They ask \"total flights from Seattle\". Likely match on Origin contains \"Seattle\". Use match query with \"Seattle\". Then size 0 (only count). Add track_total_hits: true maybe. So produce query.{\"query\":{\"match\":{\"Origin\":\"Seattle\"}},\"size\":0,\"track_total_hits\":true}"; + + // Create processors for the exact case in question + List> configs = new ArrayList<>(); + + Map regexConfig = new HashMap<>(); + regexConfig.put("type", REGEX_REPLACE); + regexConfig.put("pattern", ".*?"); + regexConfig.put("replacement", ""); + configs.add(regexConfig); + + Map extractConfig = new HashMap<>(); + extractConfig.put("type", EXTRACT_JSON); + configs.add(extractConfig); + + ProcessorChain chain = new ProcessorChain(configs); + + Object result = chain.process(input); + + assertTrue(result instanceof Map); + Map queryResult = (Map) result; + + // Verify the query parts are extracted correctly + assertTrue(queryResult.containsKey("query")); + assertTrue(queryResult.containsKey("size")); + assertEquals(0, queryResult.get("size")); + assertTrue(queryResult.containsKey("track_total_hits")); + assertEquals(true, queryResult.get("track_total_hits")); + + Map queryMap = (Map) queryResult.get("query"); + assertTrue(queryMap.containsKey("match")); + + Map matchMap = (Map) queryMap.get("match"); + assertEquals("Seattle", matchMap.get("Origin")); + } + + @Test + public void testBasicStringConditions() { + // Setup test input + Map input = new HashMap<>(); + input.put("status", "success"); + + // Create processor configs + List> processorConfigs = new ArrayList<>(); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + conditionalConfig.put("path", "$.status"); + + // ====== Ordered routes list ====== + List routes = new ArrayList<>(); + + // Success route + List> successRoute = new ArrayList<>(); + Map successToString = new HashMap<>(); + successToString.put("type", "to_string"); + successRoute.add(successToString); + + Map successRegexReplace = new HashMap<>(); + successRegexReplace.put("type", "regex_replace"); + successRegexReplace.put("pattern", "\\{.*\\}"); + successRegexReplace.put("replacement", "Operation was successful"); + successRoute.add(successRegexReplace); + + routes.add(Collections.singletonMap("success", successRoute)); + + // Error route + List> errorRoute = new ArrayList<>(); + Map errorToString = new HashMap<>(); + errorToString.put("type", "to_string"); + errorRoute.add(errorToString); + + Map errorRegexReplace = new HashMap<>(); + errorRegexReplace.put("type", "regex_replace"); + errorRegexReplace.put("pattern", "\\{.*\\}"); + errorRegexReplace.put("replacement", "Operation failed"); + errorRoute.add(errorRegexReplace); + + routes.add(Collections.singletonMap("error", errorRoute)); + + // Put ordered routes into config + conditionalConfig.put("routes", routes); + + processorConfigs.add(conditionalConfig); + + // Create and run processor chain + ProcessorChain chain = new ProcessorChain(processorConfigs); + + // Test success + Object result = chain.process(input); + assertEquals("Operation was successful", result); + + // Test error + input.put("status", "error"); + result = chain.process(input); + assertEquals("Operation failed", result); + } + + @Test + public void testNumericConditions() { + Map input = new HashMap<>(); + input.put("count", 42); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + conditionalConfig.put("path", "$.count"); + + List> routes = new ArrayList<>(); + + // >50 + List> gtRoute = new ArrayList<>(); + Map gtString = new HashMap<>(); + gtString.put("type", "to_string"); + gtRoute.add(gtString); + Map gtReplace = new HashMap<>(); + gtReplace.put("type", "regex_replace"); + gtReplace.put("pattern", ".*"); + gtReplace.put("replacement", "Greater than 50"); + gtRoute.add(gtReplace); + routes.add(Map.of(">50", gtRoute)); + + // ==42 + List> eqRoute = new ArrayList<>(); + Map eqString = new HashMap<>(); + eqString.put("type", "to_string"); + eqRoute.add(eqString); + Map eqReplace = new HashMap<>(); + eqReplace.put("type", "regex_replace"); + eqReplace.put("pattern", "^.*$"); + eqReplace.put("replacement", "Exactly 42"); + eqRoute.add(eqReplace); + routes.add(Map.of("==42", eqRoute)); + + // <50 + List> ltRoute = new ArrayList<>(); + Map ltString = new HashMap<>(); + ltString.put("type", "to_string"); + ltRoute.add(ltString); + Map ltReplace = new HashMap<>(); + ltReplace.put("type", "regex_replace"); + ltReplace.put("pattern", "^.*$"); + ltReplace.put("replacement", "Less than 50"); + ltRoute.add(ltReplace); + routes.add(Map.of("<50", ltRoute)); + + conditionalConfig.put("routes", routes); + + // Default + List> defaultRoute = new ArrayList<>(); + Map defaultString = new HashMap<>(); + defaultString.put("type", "to_string"); + defaultRoute.add(defaultString); + Map defaultReplace = new HashMap<>(); + defaultReplace.put("type", "regex_replace"); + defaultReplace.put("pattern", "^.*$"); + defaultReplace.put("replacement", "Default route"); + defaultRoute.add(defaultReplace); + conditionalConfig.put("default", defaultRoute); + + OutputProcessor processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Exactly 42", processor.process(input)); + + input.put("count", 30); + assertEquals("Less than 50", processor.process(input)); + + input.put("count", 50); + assertEquals("Default route", processor.process(input)); + } + + @Test + public void testExistenceConditions() { + Map input = new HashMap<>(); + input.put("required", "value"); + input.put("optional", null); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + conditionalConfig.put("path", "$.missing"); + + List> routes = new ArrayList<>(); + + // exists + List> existsRoute = new ArrayList<>(); + Map existsReplace = new HashMap<>(); + existsReplace.put("type", "regex_replace"); + existsReplace.put("pattern", "\\{.*\\}"); + existsReplace.put("replacement", "Field exists"); + existsRoute.add(existsReplace); + routes.add(Map.of("exists", existsRoute)); + + // not_exists + List> notExistsRoute = new ArrayList<>(); + Map notExistsReplace = new HashMap<>(); + notExistsReplace.put("type", "regex_replace"); + notExistsReplace.put("pattern", "\\{.*\\}"); + notExistsReplace.put("replacement", "Field does not exist"); + notExistsRoute.add(notExistsReplace); + routes.add(Map.of("not_exists", notExistsRoute)); + + conditionalConfig.put("routes", routes); + + OutputProcessor processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Field does not exist", processor.process(input)); + + conditionalConfig.put("path", "$.required"); + processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Field exists", processor.process(input)); + + conditionalConfig.put("path", "$.optional"); + processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Field does not exist", processor.process(input)); + } + + @Test + public void testNoMatchingConditionUsesDefault() { + Map input = new HashMap<>(); + input.put("status", "unknown"); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + conditionalConfig.put("path", "$.status"); + + List> routes = new ArrayList<>(); + + // success + List> successRoute = new ArrayList<>(); + Map successReplace = new HashMap<>(); + successReplace.put("type", "regex_replace"); + successReplace.put("pattern", "^.*$"); + successReplace.put("replacement", "Success route"); + successRoute.add(successReplace); + routes.add(Map.of("success", successRoute)); + + // error + List> errorRoute = new ArrayList<>(); + Map errorReplace = new HashMap<>(); + errorReplace.put("type", "regex_replace"); + errorReplace.put("pattern", "^.*$"); + errorReplace.put("replacement", "Error route"); + errorRoute.add(errorReplace); + routes.add(Map.of("error", errorRoute)); + + conditionalConfig.put("routes", routes); + + // default + List> defaultRoute = new ArrayList<>(); + Map defaultReplace = new HashMap<>(); + defaultReplace.put("type", "regex_replace"); + defaultReplace.put("pattern", "^.*$"); + defaultReplace.put("replacement", "Default route"); + defaultRoute.add(defaultReplace); + conditionalConfig.put("default", defaultRoute); + + OutputProcessor processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Default route", processor.process(input)); + } + + @Test + public void testChainedProcessors() { + Map input = new HashMap<>(); + input.put("status", "SUCCESS"); + + List> successChain = new ArrayList<>(); + Map step1 = new HashMap<>(); + step1.put("type", "to_string"); + successChain.add(step1); + Map step1Replace = new HashMap<>(); + step1Replace.put("type", "regex_replace"); + step1Replace.put("pattern", "^.*$"); + step1Replace.put("replacement", "Step 1"); + successChain.add(step1Replace); + Map step2Replace = new HashMap<>(); + step2Replace.put("type", "regex_replace"); + step2Replace.put("pattern", "^.*$"); + step2Replace.put("replacement", "Step 2"); + successChain.add(step2Replace); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + conditionalConfig.put("path", "$.status"); + + List> routes = new ArrayList<>(); + routes.add(Map.of("SUCCESS", successChain)); + + List> errorRoute = new ArrayList<>(); + Map errorReplace = new HashMap<>(); + errorReplace.put("type", "regex_replace"); + errorReplace.put("pattern", "^.*$"); + errorReplace.put("replacement", "Error occurred"); + errorRoute.add(errorReplace); + routes.add(Map.of("ERROR", errorRoute)); + + conditionalConfig.put("routes", routes); + + OutputProcessor processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Step 2", processor.process(input)); + } + + @Test + public void testNoPathSpecified() { + Map input = new HashMap<>(); + input.put("value", 42); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + + List> routes = new ArrayList<>(); + List> existsRoute = new ArrayList<>(); + Map existsReplace = new HashMap<>(); + existsReplace.put("type", "regex_replace"); + existsReplace.put("pattern", "^.*$"); + existsReplace.put("replacement", "Input exists"); + existsRoute.add(existsReplace); + routes.add(Map.of("exists", existsRoute)); + + conditionalConfig.put("routes", routes); + + OutputProcessor processor = ProcessorChain.ProcessorRegistry.createProcessor("conditional", conditionalConfig); + assertEquals("Input exists", processor.process(input)); + + assertNull(processor.process(null)); + } + + @Test + public void testProcessorInProcessorChain() { + Map input = new HashMap<>(); + input.put("value", 100); + + List> chainConfig = new ArrayList<>(); + + Map extractConfig = new HashMap<>(); + extractConfig.put("type", "jsonpath_filter"); + extractConfig.put("path", "$.value"); + chainConfig.add(extractConfig); + + Map conditionalConfig = new HashMap<>(); + conditionalConfig.put("type", "conditional"); + + List> routes = new ArrayList<>(); + List> gtRoute = new ArrayList<>(); + Map gtReplace = new HashMap<>(); + gtReplace.put("type", "regex_replace"); + gtReplace.put("pattern", "^.*$"); + gtReplace.put("replacement", "Greater than 50"); + gtRoute.add(gtReplace); + routes.add(Map.of(">50", gtRoute)); + + List> lteRoute = new ArrayList<>(); + Map lteReplace = new HashMap<>(); + lteReplace.put("type", "regex_replace"); + lteReplace.put("pattern", "^.*$"); + lteReplace.put("replacement", "Less than or equal to 50"); + lteRoute.add(lteReplace); + routes.add(Map.of("<=50", lteRoute)); + + conditionalConfig.put("routes", routes); + + chainConfig.add(conditionalConfig); + + ProcessorChain chain = new ProcessorChain(chainConfig); + assertEquals("Greater than 50", chain.process(input)); + } + + private ProcessorChain.OutputProcessor createRemoveJsonPathProcessor(String path) { + Map config = new HashMap<>(); + config.put("type", "remove_jsonpath"); + config.put("path", path); + return ProcessorChain.ProcessorRegistry.createProcessor("remove_jsonpath", config); + } + + @Test + public void testRemoveSimpleField() { + Map input = new HashMap<>(); + input.put("field1", "value1"); + input.put("field2", "value2"); + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.field1"); + Object result = processor.process(input); + + Map resultMap = (Map) result; + assertFalse(resultMap.containsKey("field1")); + assertEquals("value2", resultMap.get("field2")); + } + + @Test + public void testRemoveArrayElement() { + Map input = new HashMap<>(); + List items = new ArrayList<>(); + items.add("item1"); + items.add("item2"); + items.add("item3"); + input.put("items", items); + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.items[1]"); + Object result = processor.process(input); + + List resultItems = com.jayway.jsonpath.JsonPath.read(StringUtils.toJson(result), "$.items"); + assertEquals(2, resultItems.size()); + assertEquals("item1", resultItems.get(0)); + assertEquals("item3", resultItems.get(1)); + } + + @Test + public void testRemoveNestedObject() { + Map input = new HashMap<>(); + Map nested = new HashMap<>(); + nested.put("innerField", "value"); + input.put("outer", nested); + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.outer.innerField"); + Object result = processor.process(input); + + Map resultOuter = com.jayway.jsonpath.JsonPath.read(StringUtils.toJson(result), "$.outer"); + assertFalse(resultOuter.containsKey("innerField")); + } + + @Test + public void testRemoveFromNestedArray() { + Map input = new HashMap<>(); + List> items = new ArrayList<>(); + + Map item1 = new HashMap<>(); + item1.put("id", "1"); + item1.put("value", "first"); + + Map item2 = new HashMap<>(); + item2.put("id", "2"); + item2.put("value", "second"); + + items.add(item1); + items.add(item2); + input.put("items", items); + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.items[0].value"); + Object result = processor.process(input); + + Map firstItem = com.jayway.jsonpath.JsonPath.read(StringUtils.toJson(result), "$.items[0]"); + assertEquals("1", firstItem.get("id")); + assertFalse(firstItem.containsKey("value")); + } + + @Test + public void testRemoveNonExistentPath() { + Map input = new HashMap<>(); + input.put("field", "value"); + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.nonexistent.path"); + Object result = processor.process(input); + + assertEquals(input, result); + } + + @Test + public void testRemoveWithInvalidInput() { + String input = "not a json object"; + + ProcessorChain.OutputProcessor processor = createRemoveJsonPathProcessor("$.field"); + Object result = processor.process(input); + + assertEquals(input, result); + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java index 527e57e436..f99454df18 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java @@ -132,6 +132,33 @@ public void test_run_successful_2() { verifyResult(tool, createParameters(null, null, null, null)); } + @Test + public void test_run_with_output_parser() { + mockUp(); + Map params = new HashMap<>(); + params.put("output_processors", Arrays.asList(Map.of("type", "regex_replace", "pattern", "index-1", "replacement", "test-index"))); + Tool tool = ListIndexTool.Factory.getInstance().create(params); + + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + tool.run(createParameters("[\"index-1\"]", "true", "10", "true"), listener); + verify(listener).onResponse(captor.capture()); + assert captor.getValue().contains("test-index"); + assert !captor.getValue().contains("index-1"); + } + + @Test + public void test_run_without_output_parser() { + mockUp(); + Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap()); + + ActionListener listener = mock(ActionListener.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(String.class); + tool.run(createParameters("[\"index-1\"]", "true", "10", "true"), listener); + verify(listener).onResponse(captor.capture()); + assert captor.getValue().contains("index-1"); + } + private Map createParameters(String indices, String local, String pageSize, String includeUnloadedSegments) { Map parameters = new HashMap<>(); if (indices != null) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index 2c6ea41273..1d3cf8725d 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -237,4 +237,17 @@ public void testToolWithNullModelId() { public void testToolWithBlankModelId() { assertThrows(IllegalArgumentException.class, () -> new MLModelTool(client, "", "response")); } + + @Test + public void testFactoryCreateWithProcessorEnhancement() { + Map toolParams = Map + .of("model_id", "test_model_id", "response_field", "custom_response", "processor_configs", "[{\"type\":\"test_processor\"}]"); + + MLModelTool tool = MLModelTool.Factory.getInstance().create(toolParams); + + assertEquals("test_model_id", tool.getModelId()); + assertEquals("custom_response", tool.getResponseField()); + // Verify that the output parser was enhanced (not null and different from basic parser) + assertTrue(tool.getOutputParser() != null); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java index e7e21ecca9..161c43a92f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java @@ -411,6 +411,82 @@ public void testRunWithoutReturnFullResponse() { assertFalse(((String) result).contains("took")); } + @Test + @SneakyThrows + public void testRunWithOutputParserForModelTensorOutput() { + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + mockedSearchIndexTool.setOutputParser(output -> "parsed_output"); + + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> future.complete(r), e -> future.completeExceptionally(e)); + + Map parameters = new HashMap<>(); + parameters.put("input", inputString); + parameters.put(SearchIndexTool.RETURN_RAW_RESPONSE, "true"); + + mockedSearchIndexTool.run(parameters, listener); + + Object result = future.join(); + assertEquals("parsed_output", result); + } + + @Test + @SneakyThrows + public void testRunWithOutputParserForStringResponse() { + SearchResponse mockedSearchResponse = SearchResponse + .fromXContent( + JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, mockedSearchResponseString) + ); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mockedSearchResponse); + return null; + }).when(client).search(any(), any()); + + mockedSearchIndexTool.setOutputParser(output -> "parsed_string_output"); + + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}}"; + final CompletableFuture future = new CompletableFuture<>(); + ActionListener listener = ActionListener.wrap(r -> future.complete(r), e -> future.completeExceptionally(e)); + + Map parameters = new HashMap<>(); + parameters.put("input", inputString); + parameters.put(SearchIndexTool.RETURN_RAW_RESPONSE, "false"); + + mockedSearchIndexTool.run(parameters, listener); + + Object result = future.join(); + assertEquals("parsed_string_output", result); + } + + @Test + public void testFactoryCreateWithProcessorEnhancement() { + SearchIndexTool.Factory.getInstance().init(client, TEST_XCONTENT_REGISTRY_FOR_QUERY); + + Map params = new HashMap<>(); + params.put("processor_configs", "[{\"type\":\"test_processor\"}]"); + + SearchIndexTool tool = SearchIndexTool.Factory.getInstance().create(params); + + assertEquals(SearchIndexTool.TYPE, tool.getType()); + // Verify that the output parser was set (not null) + assertTrue(tool.getOutputParser() != null); + } + @Test @SneakyThrows public void testRun_withMatchQuery_triggersPlainDoubleGson() { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/parser/ToolParserTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/parser/ToolParserTests.java new file mode 100644 index 0000000000..2d10bf699f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/parser/ToolParserTests.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.tools.parser; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.ml.common.spi.tools.Parser; + +public class ToolParserTests { + + @Test + public void testCreateProcessingParserWithBaseParser() { + Parser baseParser = input -> "base_" + input; + List> processorConfigs = Collections.emptyList(); + + Parser parser = ToolParser.createProcessingParser(baseParser, processorConfigs); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("base_input", result); + } + + @Test + public void testCreateProcessingParserWithoutBaseParser() { + List> processorConfigs = Collections.emptyList(); + + Parser parser = ToolParser.createProcessingParser(null, processorConfigs); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("input", result); + } + + @Test + public void testCreateProcessingParserWithEmptyProcessors() { + Parser baseParser = input -> "base_" + input; + List> processorConfigs = Collections.emptyList(); + + Parser parser = ToolParser.createProcessingParser(baseParser, processorConfigs); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("base_input", result); + } + + @Test + public void testCreateFromToolParamsWithBaseParser() { + Parser baseParser = input -> "base_" + input; + Map params = Collections.emptyMap(); + + Parser parser = ToolParser.createFromToolParams(params, baseParser); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("base_input", result); + } + + @Test + public void testCreateFromToolParamsWithoutBaseParser() { + Map params = Collections.emptyMap(); + + Parser parser = ToolParser.createFromToolParams(params); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("input", result); + } + + @Test + public void testCreateFromToolParamsWithEmptyParams() { + Map params = Collections.emptyMap(); + + Parser parser = ToolParser.createFromToolParams(params); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("input", result); + } + + @Test + public void testCreateFromToolParamsWithNullParams() { + Parser parser = ToolParser.createFromToolParams(null); + + assertNotNull(parser); + Object result = parser.parse("input"); + assertEquals("input", result); + } +}