Skip to content

Commit f9f3764

Browse files
committed
add more unit test
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 896b4d2 commit f9f3764

File tree

4 files changed

+116
-7
lines changed

4 files changed

+116
-7
lines changed

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ public String toString() {
296296
try {
297297
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
298298
} catch (IOException e) {
299-
throw new RuntimeException(e);
299+
throw new IllegalArgumentException("Can't convert ModelTensor to string", e);
300300
}
301301
}
302302

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ public String toString() {
178178
try {
179179
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
180180
} catch (IOException e) {
181-
throw new RuntimeException(e);
181+
throw new IllegalArgumentException("Can't convert ModelTensor to string", e);
182182
}
183183
}
184184
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,9 @@ private static MLInput escapeMLInput(MLInput mlInput) {
167167
}
168168

169169
public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) {
170+
if (inputData.getParameters() == null) {
171+
return;
172+
}
170173
Map<String, String> newParameters = new HashMap<>();
171174
String noEscapeParams = inputData.getParameters().get(NO_ESCAPE_PARAMS);
172175
Set<String> noEscapParamSet = new HashSet<>();

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,17 +423,69 @@ public void testEscapeRemoteInferenceInputData_WithNoEscapeParams() {
423423
params.put("key1", inputKey1);
424424
params.put("key2", "test value");
425425
params.put("key3", inputKey3);
426-
params.put("NO_ESCAPE_PARAMS", "key1,key3");
426+
params.put("no_escape_params", "key1,key3");
427427

428428
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build();
429429

430430
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
431431

432-
String expectedKey1 = "hello \\\"world\\\"";
433-
String expectedKey3 = "special \\\"chars\\\"";
434-
assertEquals(expectedKey1, inputData.getParameters().get("key1"));
432+
assertEquals(inputKey1, inputData.getParameters().get("key1"));
435433
assertEquals("test value", inputData.getParameters().get("key2"));
436-
assertEquals(expectedKey3, inputData.getParameters().get("key3"));
434+
assertEquals(inputKey3, inputData.getParameters().get("key3"));
435+
}
436+
437+
@Test
438+
public void testEscapeRemoteInferenceInputData_NullValue() {
439+
Map<String, String> params = new HashMap<>();
440+
params.put("key1", null);
441+
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build();
442+
443+
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
444+
445+
assertNull(inputData.getParameters().get("key1"));
446+
}
447+
448+
@Test
449+
public void testEscapeRemoteInferenceInputData_JsonValue() {
450+
Map<String, String> params = new HashMap<>();
451+
params.put("key1", "{\"test\": \"value\"}");
452+
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build();
453+
454+
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
455+
456+
assertEquals("{\"test\": \"value\"}", inputData.getParameters().get("key1"));
457+
}
458+
459+
@Test
460+
public void testEscapeRemoteInferenceInputData_EscapeValue() {
461+
Map<String, String> params = new HashMap<>();
462+
params.put("key1", "test\"value");
463+
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build();
464+
465+
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
466+
467+
assertEquals("test\\\"value", inputData.getParameters().get("key1"));
468+
}
469+
470+
@Test
471+
public void testEscapeRemoteInferenceInputData_NoEscapeParam() {
472+
Map<String, String> params = new HashMap<>();
473+
params.put("key1", "test\"value");
474+
params.put("no_escape_params", "key1");
475+
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(params).build();
476+
477+
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
478+
479+
assertEquals("test\"value", inputData.getParameters().get("key1"));
480+
}
481+
482+
@Test
483+
public void testEscapeRemoteInferenceInputData_NullParameters() {
484+
RemoteInferenceInputDataSet inputData = RemoteInferenceInputDataSet.builder().parameters(null).build();
485+
486+
ConnectorUtils.escapeRemoteInferenceInputData(inputData);
487+
488+
assertNull(inputData.getParameters());
437489
}
438490

439491
@Test
@@ -538,4 +590,58 @@ public void processOutput_WithResponseFilterOnly() throws IOException {
538590
assertEquals(1, tensors.getMlModelTensors().size());
539591
assertEquals("response", tensors.getMlModelTensors().get(0).getName());
540592
}
593+
594+
@Test
595+
public void processOutput_ScriptReturnModelTensor_WithJsonResponse() throws IOException {
596+
String postprocessResult = "{\"name\":\"test\",\"data\":[1,2,3]}";
597+
when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult));
598+
599+
ConnectorAction predictAction = ConnectorAction
600+
.builder()
601+
.actionType(PREDICT)
602+
.method("POST")
603+
.url("http://test.com/mock")
604+
.requestBody("{\"input\": \"${parameters.input}\"}")
605+
.postProcessFunction("custom_script")
606+
.build();
607+
Connector connector = HttpConnector
608+
.builder()
609+
.name("test connector")
610+
.version("1")
611+
.protocol("http")
612+
.actions(Arrays.asList(predictAction))
613+
.build();
614+
String modelResponse = "{\"result\":\"test\"}";
615+
616+
ModelTensors tensors = ConnectorUtils
617+
.processOutput(PREDICT.name(), modelResponse, connector, scriptService, ImmutableMap.of(), null);
618+
619+
assertEquals(1, tensors.getMlModelTensors().size());
620+
}
621+
622+
@Test
623+
public void processOutput_WithProcessorChain_StringOutput() throws IOException {
624+
ConnectorAction predictAction = ConnectorAction
625+
.builder()
626+
.actionType(PREDICT)
627+
.method("POST")
628+
.url("http://test.com/mock")
629+
.requestBody("{\"input\": \"${parameters.input}\"}")
630+
.build();
631+
Map<String, String> parameters = new HashMap<>();
632+
parameters.put("processor_configs", "[{\"type\":\"test_processor\"}]");
633+
Connector connector = HttpConnector
634+
.builder()
635+
.name("test connector")
636+
.version("1")
637+
.protocol("http")
638+
.actions(Arrays.asList(predictAction))
639+
.build();
640+
String modelResponse = "{\"result\":\"test response\"}";
641+
642+
ModelTensors tensors = ConnectorUtils.processOutput(PREDICT.name(), modelResponse, connector, scriptService, parameters, null);
643+
644+
assertEquals(1, tensors.getMlModelTensors().size());
645+
assertEquals("response", tensors.getMlModelTensors().get(0).getName());
646+
}
541647
}

0 commit comments

Comments
 (0)