Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public class CommonValue {
public static final Version VERSION_3_1_0 = Version.fromString("3.1.0");
public static final Version VERSION_3_2_0 = Version.fromString("3.2.0");
public static final Version VERSION_3_3_0 = Version.fromString("3.3.0");
public static final Version VERSION_3_4_0 = Version.fromString("3.4.0");

// Connector Constants
public static final String NAME_FIELD = "name";
Expand All @@ -113,6 +114,7 @@ public class CommonValue {
public static final String CLIENT_CONFIG_FIELD = "client_config";
public static final String URL_FIELD = "url";
public static final String HEADERS_FIELD = "headers";
public static final String CONNECTOR_ACTION_FIELD = "connector_action";

// MCP Constants
public static final String MCP_TOOL_NAME_FIELD = "name";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea

@Override
public Optional<ConnectorAction> findAction(String action) {
if (actions != null) {
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
if (actions != null && action != null) {
if (ConnectorAction.ActionType.isValidAction(action)) {
return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst();
}
return actions.stream().filter(a -> action.equals(a.getName())).findFirst();
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.VERSION_3_4_0;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
Expand All @@ -33,6 +35,7 @@
public class ConnectorAction implements ToXContentObject, Writeable {

public static final String ACTION_TYPE_FIELD = "action_type";
public static final String NAME_FIELD = "name";
public static final String METHOD_FIELD = "method";
public static final String URL_FIELD = "url";
public static final String HEADERS_FIELD = "headers";
Expand All @@ -52,6 +55,7 @@ public class ConnectorAction implements ToXContentObject, Writeable {
private static final Logger logger = LogManager.getLogger(ConnectorAction.class);

private ActionType actionType;
private String name;
private String method;
private String url;
private Map<String, String> headers;
Expand All @@ -62,6 +66,7 @@ public class ConnectorAction implements ToXContentObject, Writeable {
@Builder(toBuilder = true)
public ConnectorAction(
ActionType actionType,
String name,
String method,
String url,
Map<String, String> headers,
Expand All @@ -78,7 +83,11 @@ public ConnectorAction(
if (method == null) {
throw new IllegalArgumentException("method can't be null");
}
if (name != null && ActionType.isValidAction(name)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment what is this name for and also why are we checking name not to be an action?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

throw new IllegalArgumentException("name can't be one of action type " + Arrays.toString(ActionType.values()));
}
this.actionType = actionType;
this.name = name;
this.method = method;
this.url = url;
this.headers = headers;
Expand All @@ -97,6 +106,9 @@ public ConnectorAction(StreamInput input) throws IOException {
this.requestBody = input.readOptionalString();
this.preProcessFunction = input.readOptionalString();
this.postProcessFunction = input.readOptionalString();
if (input.getVersion().onOrAfter(VERSION_3_4_0)) {
this.name = input.readOptionalString();
}
}

@Override
Expand All @@ -113,6 +125,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(requestBody);
out.writeOptionalString(preProcessFunction);
out.writeOptionalString(postProcessFunction);
if (out.getVersion().onOrAfter(VERSION_3_4_0)) {
out.writeOptionalString(name);
}
}

@Override
Expand All @@ -139,6 +154,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (postProcessFunction != null) {
builder.field(ACTION_POST_PROCESS_FUNCTION, postProcessFunction);
}
if (name != null) {
builder.field(NAME_FIELD, name);
}
return builder.endObject();
}

Expand All @@ -149,6 +167,7 @@ public static ConnectorAction fromStream(StreamInput in) throws IOException {

public static ConnectorAction parse(XContentParser parser) throws IOException {
ActionType actionType = null;
String name = null;
String method = null;
String url = null;
Map<String, String> headers = null;
Expand All @@ -165,6 +184,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
case ACTION_TYPE_FIELD:
actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT));
break;
case NAME_FIELD:
name = parser.text();
break;
case METHOD_FIELD:
method = parser.text();
break;
Expand All @@ -191,6 +213,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException {
return ConnectorAction
.builder()
.actionType(actionType)
.name(name)
.method(method)
.url(url)
.headers(headers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;

import java.io.IOException;
Expand Down Expand Up @@ -358,12 +359,14 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
if (!isJsonOrNdjson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else if (neededStreamParameterInPayload(parameters)) {
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
jsonObject.addProperty("stream", true);
payload = jsonObject.toString();
if (isJson(payload)) {
JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject();
jsonObject.addProperty("stream", true);
payload = jsonObject.toString();
}
}
return (T) payload;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,40 @@ public static boolean isJson(String json) {
}
}

/**
* Checks if the given string is valid JSON or NDJSON (newline-delimited JSON).
* NDJSON is commonly used for bulk operations in OpenSearch where each line is a separate JSON object.
*
* @param json the string to validate
* @return true if the string is valid JSON or NDJSON, false otherwise
*/
public static boolean isJsonOrNdjson(String json) {
if (json == null || json.isBlank()) {
return false;
}

// First check if it's regular JSON
if (isJson(json)) {
return true;
}

// Check if it's NDJSON (newline-delimited JSON)
String[] lines = json.split("\\r?\\n");
if (lines.length == 0) {
return false;
}

// Each non-empty line must be valid JSON
for (String line : lines) {
String trimmedLine = line.trim();
if (!trimmedLine.isEmpty() && !isJson(trimmedLine)) {
return false;
}
}

return true;
}

/**
* Ensures that a string is properly JSON escaped.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ private AwsConnector createAwsConnector() {

private AwsConnector createAwsConnector(Map<String, String> parameters, Map<String, String> credential, String url) {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String name = null;
String method = "POST";
Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");
Expand All @@ -222,6 +223,7 @@ private AwsConnector createAwsConnector(Map<String, String> parameters, Map<Stri

ConnectorAction action = new ConnectorAction(
actionType,
name,
method,
url,
headers,
Expand Down
Loading
Loading