diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 351171ede6..5570b14ee1 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -97,6 +97,7 @@ public class CommonValue { public static final Version VERSION_3_1_0 = Version.fromString("3.1.0"); public static final Version VERSION_3_2_0 = Version.fromString("3.2.0"); public static final Version VERSION_3_3_0 = Version.fromString("3.3.0"); + public static final Version VERSION_3_4_0 = Version.fromString("3.4.0"); // Connector Constants public static final String NAME_FIELD = "name"; @@ -113,6 +114,7 @@ public class CommonValue { public static final String CLIENT_CONFIG_FIELD = "client_config"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; + public static final String CONNECTOR_ACTION_FIELD = "connector_action"; // MCP Constants public static final String MCP_TOOL_NAME_FIELD = "name"; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 9a035230a0..05f2d3781b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -121,8 +121,11 @@ public void parseResponse(T response, List modelTensors, boolea @Override public Optional findAction(String action) { - if (actions != null) { - return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + if (actions != null && action != null) { + if (ConnectorAction.ActionType.isValidAction(action)) { + return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + } + return actions.stream().filter(a -> action.equals(a.getName())).findFirst(); } return Optional.empty(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index c82f489296..8ff29a44f2 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -6,8 +6,10 @@ package org.opensearch.ml.common.connector; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.VERSION_3_4_0; import java.io.IOException; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -33,6 +35,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final String ACTION_TYPE_FIELD = "action_type"; + public static final String NAME_FIELD = "name"; public static final String METHOD_FIELD = "method"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; @@ -52,6 +55,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { private static final Logger logger = LogManager.getLogger(ConnectorAction.class); private ActionType actionType; + private String name; private String method; private String url; private Map headers; @@ -62,6 +66,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { @Builder(toBuilder = true) public ConnectorAction( ActionType actionType, + String name, String method, String url, Map headers, @@ -78,7 +83,15 @@ public ConnectorAction( if (method == null) { throw new IllegalArgumentException("method can't be null"); } + // The 'name' field is an optional identifier for this specific action within a connector. + // It allows running a specific action by name when a connector has multiple actions of the same type. + // We validate that 'name' is not an action type (PREDICT, EXECUTE, etc.) to avoid ambiguity + // when resolving actions. + if (name != null && ActionType.isValidAction(name)) { + throw new IllegalArgumentException("name can't be one of action type " + Arrays.toString(ActionType.values())); + } this.actionType = actionType; + this.name = name; this.method = method; this.url = url; this.headers = headers; @@ -97,6 +110,9 @@ public ConnectorAction(StreamInput input) throws IOException { this.requestBody = input.readOptionalString(); this.preProcessFunction = input.readOptionalString(); this.postProcessFunction = input.readOptionalString(); + if (input.getVersion().onOrAfter(VERSION_3_4_0)) { + this.name = input.readOptionalString(); + } } @Override @@ -113,6 +129,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(requestBody); out.writeOptionalString(preProcessFunction); out.writeOptionalString(postProcessFunction); + if (out.getVersion().onOrAfter(VERSION_3_4_0)) { + out.writeOptionalString(name); + } } @Override @@ -139,6 +158,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (postProcessFunction != null) { builder.field(ACTION_POST_PROCESS_FUNCTION, postProcessFunction); } + if (name != null) { + builder.field(NAME_FIELD, name); + } return builder.endObject(); } @@ -149,6 +171,7 @@ public static ConnectorAction fromStream(StreamInput in) throws IOException { public static ConnectorAction parse(XContentParser parser) throws IOException { ActionType actionType = null; + String name = null; String method = null; String url = null; Map headers = null; @@ -165,6 +188,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { case ACTION_TYPE_FIELD: actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; + case NAME_FIELD: + name = parser.text(); + break; case METHOD_FIELD: method = parser.text(); break; @@ -191,6 +217,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { return ConnectorAction .builder() .actionType(actionType) + .name(name) .method(method) .url(url) .headers(headers) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 53f66ce384..c93a8e6abb 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson; import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import java.io.IOException; @@ -358,12 +359,14 @@ public T createPayload(String action, Map parameters) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - if (!isJson(payload)) { + if (!isJsonOrNdjson(payload)) { throw new IllegalArgumentException("Invalid payload: " + payload); } else if (neededStreamParameterInPayload(parameters)) { - JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); - jsonObject.addProperty("stream", true); - payload = jsonObject.toString(); + if (isJson(payload)) { + JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); + jsonObject.addProperty("stream", true); + payload = jsonObject.toString(); + } } return (T) payload; } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index a3f1a3b416..039b1c2bd1 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -136,6 +136,40 @@ public static boolean isJson(String json) { } } + /** + * Checks if the given string is valid JSON or NDJSON (newline-delimited JSON). + * NDJSON is commonly used for bulk operations in OpenSearch where each line is a separate JSON object. + * + * @param json the string to validate + * @return true if the string is valid JSON or NDJSON, false otherwise + */ + public static boolean isJsonOrNdjson(String json) { + if (json == null || json.isBlank()) { + return false; + } + + // First check if it's regular JSON + if (isJson(json)) { + return true; + } + + // Check if it's NDJSON (newline-delimited JSON) + String[] lines = json.split("\\r?\\n"); + if (lines.length == 0) { + return false; + } + + // Each non-empty line must be valid JSON + for (String line : lines) { + String trimmedLine = line.trim(); + if (!trimmedLine.isEmpty() && !isJson(trimmedLine)) { + return false; + } + } + + return true; + } + /** * Ensures that a string is properly JSON escaped. * diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index 2b679b8bbe..63033a2316 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -213,6 +213,7 @@ private AwsConnector createAwsConnector() { private AwsConnector createAwsConnector(Map parameters, Map credential, String url) { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; Map headers = new HashMap<>(); headers.put("api_key", "${credential.key}"); @@ -222,6 +223,7 @@ private AwsConnector createAwsConnector(Map parameters, Map new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(null, null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("action type can't be null", exception.getMessage()); @@ -109,7 +109,7 @@ public void constructor_NullActionType() { public void constructor_NullUrl() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) ); assertEquals("url can't be null", exception.getMessage()); } @@ -118,14 +118,23 @@ public void constructor_NullUrl() { public void constructor_NullMethod() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, null, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("method can't be null", exception.getMessage()); } @Test public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess() { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, OPENAI_URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + null, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + null, + null + ); action.validatePrePostProcessFunctions(Map.of()); assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); } @@ -134,6 +143,7 @@ public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess public void testValidatePrePostProcessFunctionsWithExternalServers() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, null, @@ -151,6 +161,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces "\"\\n StringBuilder builder = new StringBuilder();\\n builder.append(\\\"\\\\\\\"\\\");\\n String first = params.text_docs[0];\\n builder.append(first);\\n builder.append(\\\"\\\\\\\"\\\");\\n def parameters = \\\"{\\\" +\\\"\\\\\\\"text_inputs\\\\\\\":\\\" + builder + \\\"}\\\";\\n return \\\"{\\\" +\\\"\\\\\\\"parameters\\\\\\\":\\\" + parameters + \\\"}\\\";\""; ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -166,6 +177,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, "https://${parameters.endpoint}/v1/chat/completions", null, @@ -181,6 +193,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -206,6 +219,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -231,6 +245,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -243,6 +258,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -255,6 +271,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -270,6 +287,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -295,6 +313,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -320,6 +339,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -332,6 +352,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -344,6 +365,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -359,6 +381,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -384,6 +407,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -409,6 +433,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -421,6 +446,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -436,6 +462,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -463,6 +490,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -488,7 +516,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil @Test public void writeTo_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); @@ -504,6 +532,7 @@ public void writeTo() throws IOException { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, @@ -519,7 +548,7 @@ public void writeTo() throws IOException { @Test public void toXContent_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -540,6 +569,7 @@ public void toXContent() throws IOException { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 1038006f2c..9d5d69dac3 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -15,6 +15,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.BiFunction; import org.junit.Assert; @@ -379,6 +380,7 @@ public static HttpConnector createHttpConnector() { public static HttpConnector createHttpConnectorWithRequestBody(String requestBody) { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -388,6 +390,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, @@ -531,4 +534,154 @@ public void testParseResponse_NonStringNonMapResponse() throws IOException { Assert.assertEquals(42, modelTensors.get(0).getDataAsMap().get("response")); } + @Test + public void testFindAction_WithValidActionType() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("PREDICT"); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithValidActionTypeCaseInsensitive() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("predict"); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithCustomActionName() { + String customActionName = "custom_action"; + ConnectorAction customAction = new ConnectorAction( + PREDICT, + customActionName, + "POST", + "https://custom.com", + null, + "{\"input\": \"test\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(customAction)) + .build(); + + Optional action = connector.findAction(customActionName); + Assert.assertTrue(action.isPresent()); + Assert.assertEquals(customActionName, action.get().getName()); + Assert.assertEquals(PREDICT, action.get().getActionType()); + } + + @Test + public void testFindAction_WithNullAction() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction(null); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_WithInvalidActionType() { + HttpConnector connector = createHttpConnector(); + Optional action = connector.findAction("INVALID_ACTION"); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_WithNullActions() { + HttpConnector connector = HttpConnector.builder().name("test_connector").protocol("http").actions(null).build(); + Optional action = connector.findAction("PREDICT"); + Assert.assertFalse(action.isPresent()); + } + + @Test + public void testFindAction_CustomNameTakesPrecedenceOverActionType() { + String customActionName = "my_predict"; + ConnectorAction action1 = new ConnectorAction( + PREDICT, + null, + "POST", + "https://test1.com", + null, + "{\"input\": \"test1\"}", + null, + null + ); + ConnectorAction action2 = new ConnectorAction( + ConnectorAction.ActionType.EXECUTE, + customActionName, + "POST", + "https://test2.com", + null, + "{\"input\": \"test2\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(action1, action2)) + .build(); + + // When searching by valid action type, should find by action type first + Optional foundByType = connector.findAction("PREDICT"); + Assert.assertTrue(foundByType.isPresent()); + Assert.assertEquals(PREDICT, foundByType.get().getActionType()); + Assert.assertEquals("https://test1.com", foundByType.get().getUrl()); + + // When searching by custom name, should find by name + Optional foundByName = connector.findAction(customActionName); + Assert.assertTrue(foundByName.isPresent()); + Assert.assertEquals(customActionName, foundByName.get().getName()); + Assert.assertEquals("https://test2.com", foundByName.get().getUrl()); + } + + @Test + public void testFindAction_MultipleActionsWithSameType() { + ConnectorAction action1 = new ConnectorAction( + PREDICT, + "predict_action_1", + "POST", + "https://test1.com", + null, + "{\"input\": \"test1\"}", + null, + null + ); + ConnectorAction action2 = new ConnectorAction( + PREDICT, + "predict_action_2", + "POST", + "https://test2.com", + null, + "{\"input\": \"test2\"}", + null, + null + ); + + HttpConnector connector = HttpConnector + .builder() + .name("test_connector") + .protocol("http") + .actions(Arrays.asList(action1, action2)) + .build(); + + // Should return the first matching action when searching by type + Optional foundByType = connector.findAction("PREDICT"); + Assert.assertTrue(foundByType.isPresent()); + Assert.assertEquals("predict_action_1", foundByType.get().getName()); + + // Should find specific action by custom name + Optional foundByName = connector.findAction("predict_action_2"); + Assert.assertTrue(foundByName.isPresent()); + Assert.assertEquals("predict_action_2", foundByName.get().getName()); + Assert.assertEquals("https://test2.com", foundByName.get().getUrl()); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index a7df00618a..b84caceb59 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -73,6 +73,7 @@ public class MLCreateConnectorInputTests { @Before public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -82,6 +83,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index b4f7629689..9dbb083e76 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -39,6 +39,7 @@ public class MLCreateConnectorRequestTests { @Before public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String name = null; String method = "POST"; String url = "https://test.com"; Map headers = new HashMap<>(); @@ -48,6 +49,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + name, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index e81ccc54a3..4d6e2b7ecd 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -77,6 +77,95 @@ public void isJson_False() { assertFalse(StringUtils.isJson("[abc\n123]")); } + @Test + public void isJsonOrNdjson_NullInput() { + assertFalse(StringUtils.isJsonOrNdjson(null)); + } + + @Test + public void isJsonOrNdjson_BlankInput() { + assertFalse(StringUtils.isJsonOrNdjson("")); + assertFalse(StringUtils.isJsonOrNdjson(" ")); + assertFalse(StringUtils.isJsonOrNdjson("\n")); + assertFalse(StringUtils.isJsonOrNdjson("\t")); + } + + @Test + public void isJsonOrNdjson_ValidJson() { + // Valid JSON objects should return true + assertTrue(StringUtils.isJsonOrNdjson("{}")); + assertTrue(StringUtils.isJsonOrNdjson("[]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": 123}")); + assertTrue(StringUtils.isJsonOrNdjson("[1, 2, 3]")); + assertTrue(StringUtils.isJsonOrNdjson("[\"a\", \"b\"]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value\", \"key2\": 123}")); + } + + @Test + public void isJsonOrNdjson_ValidNdjson() { + // Valid NDJSON (newline-delimited JSON) should return true + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"index\": {}}\n{\"field\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("[1, 2, 3]\n[4, 5, 6]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2}\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_ValidNdjsonWithCarriageReturn() { + // NDJSON with \r\n line endings should return true + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\r\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\r\n{\"b\": 2}\r\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_NdjsonWithEmptyLines() { + // NDJSON with empty lines should return true (empty lines are ignored) + assertTrue(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n\n{\"key2\": \"value2\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n \n{\"b\": 2}")); + assertTrue(StringUtils.isJsonOrNdjson("\n{\"key\": \"value\"}\n")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": \"value\"}\n\n")); + } + + @Test + public void isJsonOrNdjson_InvalidJson() { + // Invalid JSON should return false + assertFalse(StringUtils.isJsonOrNdjson("{")); + assertFalse(StringUtils.isJsonOrNdjson("[")); + assertFalse(StringUtils.isJsonOrNdjson("{\"key\": \"value}")); + assertFalse(StringUtils.isJsonOrNdjson("[1, \"a]")); + assertFalse(StringUtils.isJsonOrNdjson("not json")); + assertFalse(StringUtils.isJsonOrNdjson("123abc")); + } + + @Test + public void isJsonOrNdjson_InvalidNdjson() { + // NDJSON with at least one invalid line should return false + assertFalse(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\ninvalid json")); + assertFalse(StringUtils.isJsonOrNdjson("{\"key1\": \"value1\"}\n{\"key2\": \"value2}\n{\"key3\": \"value3\"}")); + assertFalse(StringUtils.isJsonOrNdjson("invalid\n{\"key\": \"value\"}")); + assertFalse(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2\n{\"c\": 3}")); + } + + @Test + public void isJsonOrNdjson_MixedValidInvalidLines() { + // Mix of valid and invalid JSON lines should return false + assertFalse(StringUtils.isJsonOrNdjson("{\"valid\": true}\n{invalid}\n{\"also_valid\": true}")); + assertFalse(StringUtils.isJsonOrNdjson("[1, 2, 3]\nplain text\n[4, 5, 6]")); + } + + @Test + public void isJsonOrNdjson_OpenSearchBulkFormat() { + // OpenSearch bulk API format (action/metadata line followed by document) + assertTrue(StringUtils.isJsonOrNdjson("{\"index\": {\"_index\": \"test\"}}\n{\"field\": \"value\"}")); + assertTrue( + StringUtils + .isJsonOrNdjson( + "{\"index\": {\"_index\": \"test\", \"_id\": \"1\"}}\n{\"field1\": \"value1\"}\n{\"index\": {\"_index\": \"test\", \"_id\": \"2\"}}\n{\"field2\": \"value2\"}" + ) + ); + } + @Test public void toUTF8() { String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..f500ae32d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -9,8 +9,10 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; import java.security.AccessController; import java.security.PrivilegedExceptionAction; @@ -110,7 +112,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..45b318bc6c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -8,8 +8,10 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; import java.security.AccessController; import java.security.PrivilegedExceptionAction; @@ -109,7 +111,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java index c1b5db1778..54be390178 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.connector; +import static org.opensearch.ml.common.CommonValue.CONNECTOR_ACTION_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import org.opensearch.ResourceNotFoundException; @@ -18,6 +19,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; @@ -73,14 +75,24 @@ public ExecuteConnectorTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); String connectorId = executeConnectorRequest.getConnectorId(); + if (executeConnectorRequest.getMlInput() == null) { + actionListener.onFailure(new IllegalArgumentException("MLInput cannot be null")); + return; + } + + RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) executeConnectorRequest.getMlInput().getInputDataset(); String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); + if (inputDataset.getParameters() != null && inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD) != null) { + connectorAction = inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD); + } if (MLIndicesHandler .doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) { + String finalConnectorAction = connectorAction; ActionListener listener = ActionListener.wrap(connector -> { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. - connector.decrypt(connectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + connector.decrypt(finalConnectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); @@ -89,7 +101,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .executeAction(finalConnectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { actionListener.onResponse(taskResponse); }, e -> { actionListener.onFailure(e); })); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java index 4383cc0f86..6d401a8a3a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportActionTests.java @@ -119,6 +119,10 @@ public void setup() { public void testExecute_NoConnectorIndex() { when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); + when(request.getMlInput()).thenReturn(org.opensearch.ml.common.input.MLInput.builder() + .algorithm(org.opensearch.ml.common.FunctionName.REMOTE) + .inputDataset(new org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet(Map.of(), null)) + .build()); action.doExecute(task, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(argCaptor.capture()); @@ -128,6 +132,10 @@ public void testExecute_NoConnectorIndex() { public void testExecute_FailedToGetConnector() { when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(true); when(metaData.hasIndex(anyString())).thenReturn(true); + when(request.getMlInput()).thenReturn(org.opensearch.ml.common.input.MLInput.builder() + .algorithm(org.opensearch.ml.common.FunctionName.REMOTE) + .inputDataset(new org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet(Map.of(), null)) + .build()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2);