Skip to content

Commit ae40e21

Browse files
Unified Nova MME
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent c243f8a commit ae40e21

File tree

5 files changed

+286
-0
lines changed

5 files changed

+286
-0
lines changed

common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.ml.common.connector.functions.preprocess.CohereRerankPreProcessFunction;
1818
import org.opensearch.ml.common.connector.functions.preprocess.ImageEmbeddingPreProcessFunction;
1919
import org.opensearch.ml.common.connector.functions.preprocess.MultiModalConnectorPreProcessFunction;
20+
import org.opensearch.ml.common.connector.functions.preprocess.NovaMultiModalEmbeddingPreProcessFunction;
2021
import org.opensearch.ml.common.connector.functions.preprocess.OpenAIEmbeddingPreProcessFunction;
2122
import org.opensearch.ml.common.connector.functions.preprocess.VideoEmbeddingPreProcessFunction;
2223
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
@@ -30,6 +31,7 @@ public class MLPreProcessFunction {
3031
public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding";
3132
public static final String TEXT_DOCS_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.embedding";
3233
public static final String TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT = "connector.pre_process.bedrock.multimodal_embedding";
34+
public static final String BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.embedding";
3335
public static final String TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.text_embedding";
3436
public static final String IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.image_embedding";
3537
public static final String VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT = "connector.pre_process.bedrock.nova.video_embedding";
@@ -49,6 +51,8 @@ public class MLPreProcessFunction {
4951
CohereRerankPreProcessFunction cohereRerankPreProcessFunction = new CohereRerankPreProcessFunction();
5052
BedrockRerankPreProcessFunction bedrockRerankPreProcessFunction = new BedrockRerankPreProcessFunction();
5153
MultiModalConnectorPreProcessFunction multiModalEmbeddingPreProcessFunction = new MultiModalConnectorPreProcessFunction();
54+
NovaMultiModalEmbeddingPreProcessFunction novaMultiModalEmbeddingPreProcessFunction =
55+
new NovaMultiModalEmbeddingPreProcessFunction();
5256
ImageEmbeddingPreProcessFunction imageEmbeddingPreProcessFunction = new ImageEmbeddingPreProcessFunction();
5357
VideoEmbeddingPreProcessFunction videoEmbeddingPreProcessFunction = new VideoEmbeddingPreProcessFunction();
5458
AudioEmbeddingPreProcessFunction audioEmbeddingPreProcessFunction = new AudioEmbeddingPreProcessFunction();
@@ -57,6 +61,7 @@ public class MLPreProcessFunction {
5761
PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereEmbeddingPreProcessFunction);
5862
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_COHERE_MULTI_MODAL_EMBEDDING_INPUT, cohereMultiModalEmbeddingPreProcessFunction);
5963
PRE_PROCESS_FUNCTIONS.put(TEXT_IMAGE_TO_BEDROCK_EMBEDDING_INPUT, multiModalEmbeddingPreProcessFunction);
64+
PRE_PROCESS_FUNCTIONS.put(BEDROCK_NOVA_MULTI_MODAL_EMBEDDING_INPUT, novaMultiModalEmbeddingPreProcessFunction);
6065
PRE_PROCESS_FUNCTIONS.put(TEXT_TO_BEDROCK_NOVA_EMBEDDING_INPUT, bedrockEmbeddingPreProcessFunction);
6166
PRE_PROCESS_FUNCTIONS.put(IMAGE_TO_BEDROCK_NOVA_EMBEDDING_INPUT, imageEmbeddingPreProcessFunction);
6267
PRE_PROCESS_FUNCTIONS.put(VIDEO_TO_BEDROCK_NOVA_EMBEDDING_INPUT, videoEmbeddingPreProcessFunction);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.opensearch.ml.common.utils.StringUtils.convertScriptStringToJsonString;
9+
10+
import java.util.HashMap;
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
15+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
16+
import org.opensearch.ml.common.input.MLInput;
17+
18+
import lombok.extern.log4j.Log4j2;
19+
20+
@Log4j2
21+
public class NovaMultiModalEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
22+
23+
public NovaMultiModalEmbeddingPreProcessFunction() {
24+
this.returnDirectlyForRemoteInferenceInput = true;
25+
}
26+
27+
@Override
28+
public void validate(MLInput mlInput) {
29+
validateTextDocsInput(mlInput);
30+
List<String> docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs();
31+
if (docs.size() == 0 || (docs.size() == 1 && docs.get(0) == null)) {
32+
throw new IllegalArgumentException("No input provided");
33+
}
34+
}
35+
36+
@Override
37+
public RemoteInferenceInputDataSet process(MLInput mlInput) {
38+
TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset();
39+
String input = inputData.getDocs().get(0);
40+
41+
Map<String, String> parametersMap = new HashMap<>();
42+
String parameterName = detectModalityParameter(input);
43+
parametersMap.put(parameterName, input);
44+
45+
return RemoteInferenceInputDataSet
46+
.builder()
47+
.parameters(convertScriptStringToJsonString(Map.of("parameters", parametersMap)))
48+
.build();
49+
}
50+
51+
private String detectModalityParameter(String input) {
52+
try {
53+
if (input.contains("\"text\"")) {
54+
return "inputText";
55+
}
56+
if (input.contains("\"image\"")) {
57+
return "inputImage";
58+
}
59+
if (input.contains("\"video\"")) {
60+
return "inputVideo";
61+
}
62+
if (input.contains("\"audio\"")) {
63+
return "inputAudio";
64+
}
65+
return "inputText";
66+
} catch (Exception e) {
67+
log.warn("Failed to detect modality from input, defaulting to text: {}", e.getMessage());
68+
return "inputText";
69+
}
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.connector.functions.preprocess;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.mockito.Mockito.mock;
10+
import static org.mockito.Mockito.when;
11+
12+
import java.util.Arrays;
13+
import java.util.Collections;
14+
import java.util.Map;
15+
16+
import org.junit.Before;
17+
import org.junit.Rule;
18+
import org.junit.Test;
19+
import org.junit.rules.ExpectedException;
20+
import org.opensearch.ml.common.FunctionName;
21+
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
22+
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
23+
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
24+
import org.opensearch.ml.common.input.MLInput;
25+
26+
public class NovaMultiModalEmbeddingPreProcessFunctionTest {
27+
@Rule
28+
public ExpectedException exceptionRule = ExpectedException.none();
29+
30+
NovaMultiModalEmbeddingPreProcessFunction function;
31+
32+
TextSimilarityInputDataSet textSimilarityInputDataSet;
33+
TextDocsInputDataSet textDocsInputDataSet;
34+
RemoteInferenceInputDataSet remoteInferenceInputDataSet;
35+
36+
MLInput textEmbeddingInput;
37+
MLInput textSimilarityInput;
38+
MLInput remoteInferenceInput;
39+
40+
@Before
41+
public void setUp() {
42+
function = new NovaMultiModalEmbeddingPreProcessFunction();
43+
textSimilarityInputDataSet = TextSimilarityInputDataSet.builder().queryText("test").textDocs(Arrays.asList("hello")).build();
44+
textDocsInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("hello")).build();
45+
remoteInferenceInputDataSet = RemoteInferenceInputDataSet.builder().parameters(Map.of("key1", "value1", "key2", "value2")).build();
46+
47+
textEmbeddingInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
48+
textSimilarityInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(textSimilarityInputDataSet).build();
49+
remoteInferenceInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(remoteInferenceInputDataSet).build();
50+
}
51+
52+
@Test
53+
public void process_NullInput() {
54+
exceptionRule.expect(IllegalArgumentException.class);
55+
exceptionRule.expectMessage("Preprocess function input can't be null");
56+
function.apply(null);
57+
}
58+
59+
@Test
60+
public void process_WrongInput() {
61+
exceptionRule.expect(IllegalArgumentException.class);
62+
exceptionRule.expectMessage("This pre_process_function can only support TextDocsInputDataSet");
63+
function.apply(textSimilarityInput);
64+
}
65+
66+
@Test
67+
public void process_TextInput() {
68+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build();
69+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
70+
assertEquals(1, dataSet.getParameters().size());
71+
assertEquals("hello", dataSet.getParameters().get("inputText"));
72+
}
73+
74+
@Test
75+
public void process_JsonImageInput() {
76+
TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"image\": \"base64data\"}")).build();
77+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build();
78+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
79+
assertEquals(1, dataSet.getParameters().size());
80+
assertEquals("{\"image\": \"base64data\"}", dataSet.getParameters().get("inputImage"));
81+
}
82+
83+
@Test
84+
public void process_VideoInput() {
85+
TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"video\": \"videodata\"}")).build();
86+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build();
87+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
88+
assertEquals(1, dataSet.getParameters().size());
89+
assertEquals("{\"video\": \"videodata\"}", dataSet.getParameters().get("inputVideo"));
90+
}
91+
92+
@Test
93+
public void process_AudioInput() {
94+
TextDocsInputDataSet jsonInputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("{\"audio\": \"audiodata\"}")).build();
95+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(jsonInputDataSet).build();
96+
RemoteInferenceInputDataSet dataSet = function.apply(mlInput);
97+
assertEquals(1, dataSet.getParameters().size());
98+
assertEquals("{\"audio\": \"audiodata\"}", dataSet.getParameters().get("inputAudio"));
99+
}
100+
101+
@Test
102+
public void process_EmptyDocs() {
103+
TextDocsInputDataSet mockDataSet = mock(TextDocsInputDataSet.class);
104+
when(mockDataSet.getDocs()).thenReturn(Collections.emptyList());
105+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(mockDataSet).build();
106+
107+
exceptionRule.expect(IllegalArgumentException.class);
108+
exceptionRule.expectMessage("No input provided");
109+
function.apply(mlInput);
110+
}
111+
112+
@Test
113+
public void process_RemoteInferenceInput() {
114+
RemoteInferenceInputDataSet dataSet = function.apply(remoteInferenceInput);
115+
assertEquals(remoteInferenceInputDataSet, dataSet);
116+
}
117+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
import org.opensearch.ml.engine.processor.ProcessorChain;
5454
import org.opensearch.script.ScriptService;
5555

56+
import com.google.gson.JsonElement;
57+
import com.google.gson.JsonObject;
58+
import com.google.gson.JsonParser;
5659
import com.jayway.jsonpath.JsonPath;
5760

5861
import lombok.extern.log4j.Log4j2;
@@ -72,6 +75,7 @@ public class ConnectorUtils {
7275

7376
private static final AwsV4HttpSigner signer;
7477
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";
78+
public static final String BEDROCK_NOVA_MODEL = "amazon.nova-2-multimodal-embeddings-v1:0";
7579

7680
static {
7781
signer = AwsV4HttpSigner.create();
@@ -340,6 +344,13 @@ public static SdkHttpFullRequest buildSdkRequest(
340344
SdkHttpMethod method
341345
) {
342346
String charset = parameters.getOrDefault("charset", "UTF-8");
347+
348+
// Clean empty JSON sections for Bedrock Nova embedding requests
349+
String model = connector.getParameters().get("model");
350+
if (payload != null && model != null && model.equals(BEDROCK_NOVA_MODEL)) {
351+
payload = cleanBedrockNovaRequest(payload);
352+
}
353+
343354
RequestBody requestBody;
344355
if (payload != null) {
345356
requestBody = RequestBody.fromString(payload, Charset.forName(charset));
@@ -480,4 +491,38 @@ public static ConnectorAction createConnectorAction(Connector connector, Connect
480491
.headers(batchPredictAction.get().getHeaders())
481492
.build();
482493
}
494+
495+
private static String cleanBedrockNovaRequest(String json) {
496+
try {
497+
JsonObject root = JsonParser.parseString(json).getAsJsonObject();
498+
JsonObject params = root.getAsJsonObject("singleEmbeddingParams");
499+
if (params == null)
500+
return json;
501+
502+
removeIfNull(params, "text");
503+
removeIfNull(params, "image");
504+
removeIfNull(params, "video");
505+
removeIfNull(params, "audio");
506+
507+
return gson.toJson(root);
508+
} catch (Exception e) {
509+
log.warn("Failed to clean empty JSON sections: {}", e.getMessage());
510+
return json;
511+
}
512+
}
513+
514+
private static void removeIfNull(JsonObject parent, String fieldName) {
515+
JsonObject field = parent.getAsJsonObject(fieldName);
516+
if (field == null)
517+
return;
518+
519+
// Check text field's value or other fields' source.bytes
520+
JsonElement element = "text".equals(fieldName)
521+
? field.get("value")
522+
: (field.getAsJsonObject("source") != null ? field.getAsJsonObject("source").get("bytes") : null);
523+
524+
if (element != null && element.isJsonNull()) {
525+
parent.remove(fieldName);
526+
}
527+
}
483528
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
import static org.junit.Assert.assertNull;
1111
import static org.junit.Assert.assertTrue;
1212
import static org.mockito.ArgumentMatchers.any;
13+
import static org.mockito.Mockito.mock;
1314
import static org.mockito.Mockito.spy;
1415
import static org.mockito.Mockito.when;
1516
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.BATCH_PREDICT_STATUS;
1617
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
1718
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
1819
import static org.opensearch.ml.common.utils.StringUtils.gson;
20+
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.BEDROCK_NOVA_MODEL;
1921

2022
import java.io.IOException;
2123
import java.util.ArrayList;
@@ -47,6 +49,7 @@
4749
import com.google.common.collect.ImmutableMap;
4850

4951
import okhttp3.Request;
52+
import software.amazon.awssdk.http.SdkHttpFullRequest;
5053

5154
public class ConnectorUtilsTest {
5255

@@ -1057,6 +1060,51 @@ public void buildSdkRequest_CancelBatchPredictWithEmptyPayload() {
10571060
}
10581061
}
10591062

1063+
@Test
1064+
public void buildSdkRequest_NovaModelCleansJson() throws IOException {
1065+
Connector connector = mock(Connector.class);
1066+
when(connector.getParameters()).thenReturn(Map.of("model", BEDROCK_NOVA_MODEL));
1067+
when(connector.getActionEndpoint("predict", Map.of()))
1068+
.thenReturn("https://bedrock-runtime.us-east-1.amazonaws.com/model/test/invoke");
1069+
when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json"));
1070+
1071+
String payloadWithNulls =
1072+
"{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"},\"video\":{\"source\":{\"bytes\":null}},\"audio\":{\"source\":{\"bytes\":null}}}}";
1073+
1074+
SdkHttpFullRequest request = ConnectorUtils
1075+
.buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST);
1076+
1077+
// Verify request was created successfully
1078+
assertNotNull(request);
1079+
assertTrue(request.contentStreamProvider().isPresent());
1080+
1081+
// Verify the payload was cleaned, null values removed
1082+
String actualPayload = new String(request.contentStreamProvider().get().newStream().readAllBytes());
1083+
String expectedPayload = "{\"singleEmbeddingParams\":{\"text\":{\"value\":\"hello\"}}}";
1084+
assertEquals(expectedPayload, actualPayload);
1085+
}
1086+
1087+
@Test
1088+
public void testBuildSdkRequest_NonNovaModelSkipsCleaning() throws IOException {
1089+
Connector connector = mock(Connector.class);
1090+
when(connector.getParameters()).thenReturn(Map.of("model", "gpt-3.5-turbo"));
1091+
when(connector.getActionEndpoint("predict", Map.of())).thenReturn("https://api.openai.com/v1/chat/completions");
1092+
when(connector.getDecryptedHeaders()).thenReturn(Map.of("Content-Type", "application/json"));
1093+
1094+
String payloadWithNulls = "{\"video\":{\"source\":{\"bytes\":null}}}";
1095+
1096+
SdkHttpFullRequest request = ConnectorUtils
1097+
.buildSdkRequest("predict", connector, Map.of(), payloadWithNulls, software.amazon.awssdk.http.SdkHttpMethod.POST);
1098+
1099+
// Verify request was created successfully
1100+
assertNotNull(request);
1101+
assertTrue(request.contentStreamProvider().isPresent());
1102+
1103+
// Verify the payload was not cleaned, null values preserved
1104+
String actualPayload = new String(request.contentStreamProvider().get().newStream().readAllBytes());
1105+
assertEquals(payloadWithNulls, actualPayload);
1106+
}
1107+
10601108
@Test
10611109
public void createConnectorAction_WithEmptyParameters() {
10621110
Connector connector = HttpConnector

0 commit comments

Comments
 (0)