Skip to content

Commit a5f0f5b

Browse files
committed
add unit test
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 13e2f43 commit a5f0f5b

File tree

13 files changed

+524
-1
lines changed

13 files changed

+524
-1
lines changed

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,20 @@ public void parseResponse_NonJsonString() throws IOException {
304304
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
305305
}
306306

307+
@Test
308+
public void parseResponse_MapResponse() throws IOException {
309+
HttpConnector connector = createHttpConnector();
310+
Map<String, Object> responseMap = new HashMap<>();
311+
responseMap.put("key1", "value1");
312+
responseMap.put("key2", "value2");
313+
List<ModelTensor> modelTensors = new ArrayList<>();
314+
315+
connector.parseResponse(responseMap, modelTensors, false);
316+
Assert.assertEquals(1, modelTensors.size());
317+
Assert.assertEquals("response", modelTensors.get(0).getName());
318+
Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap());
319+
}
320+
307321
@Test
308322
public void fillNullParameters() {
309323
HttpConnector connector = createHttpConnector();
@@ -442,4 +456,33 @@ public void parse_WithTenantId() throws IOException {
442456
Assert.assertEquals("test_tenant", connector.getTenantId());
443457
}
444458

459+
@Test
460+
public void testParseResponse_MapResponse() throws IOException {
461+
HttpConnector connector = createHttpConnector();
462+
463+
Map<String, Object> responseMap = new HashMap<>();
464+
responseMap.put("result", "success");
465+
responseMap.put("data", Arrays.asList("item1", "item2"));
466+
467+
List<ModelTensor> modelTensors = new ArrayList<>();
468+
connector.parseResponse(responseMap, modelTensors, false);
469+
470+
Assert.assertEquals(1, modelTensors.size());
471+
Assert.assertEquals("response", modelTensors.get(0).getName());
472+
Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap());
473+
}
474+
475+
@Test
476+
public void testParseResponse_NonStringNonMapResponse() throws IOException {
477+
HttpConnector connector = createHttpConnector();
478+
479+
Integer numericResponse = 42;
480+
List<ModelTensor> modelTensors = new ArrayList<>();
481+
connector.parseResponse(numericResponse, modelTensors, false);
482+
483+
Assert.assertEquals(1, modelTensors.size());
484+
Assert.assertEquals("response", modelTensors.get(0).getName());
485+
Assert.assertEquals(42, modelTensors.get(0).getDataAsMap().get("response"));
486+
}
487+
445488
}

common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ public void parse_SkipIrrelevantFields() throws IOException {
170170
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
171171
}
172172

173+
@Test
174+
public void test_ToString() {
175+
String result = modelTensorOutput.toString();
176+
String expected =
177+
"{\"inference_results\":[{\"output\":[{\"name\":\"test\",\"data_type\":\"FLOAT32\",\"shape\":[1,3],\"data\":[1.0,2.0,3.0],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}]}";
178+
assertEquals(expected, result);
179+
}
180+
173181
private void readInputStream(ModelTensorOutput input, Consumer<ModelTensorOutput> verify) throws IOException {
174182
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
175183
input.writeTo(bytesStreamOutput);

common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,12 @@ public void test_NullDataType() {
122122
.byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 }))
123123
.build();
124124
}
125+
126+
@Test
127+
public void test_ToString() {
128+
String result = modelTensor.toString();
129+
String expected =
130+
"{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"},\"result\":\"test result\",\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}";
131+
assertEquals(expected, result);
132+
}
125133
}

common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,12 @@ public void parse_SkipIrrelevantFields() throws IOException {
274274
ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0);
275275
assertEquals("test_tensor", modelTensor.getName());
276276
}
277+
278+
@Test
279+
public void test_ToString() {
280+
String result = modelTensors.toString();
281+
String expected =
282+
"{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}";
283+
assertEquals(expected, result);
284+
}
277285
}

common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
import org.junit.Test;
3232
import org.opensearch.OpenSearchParseException;
3333
import org.opensearch.action.ActionRequestValidationException;
34+
import org.opensearch.ml.common.output.model.MLResultDataType;
35+
import org.opensearch.ml.common.output.model.ModelTensor;
36+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
37+
import org.opensearch.ml.common.output.model.ModelTensors;
3438

3539
import com.google.gson.JsonElement;
3640
import com.google.gson.TypeAdapter;
@@ -244,6 +248,48 @@ public void testGetErrorMessageWhenHiddenNull() {
244248
* in the values. Verifies that the method correctly extracts the prefixes of the toString()
245249
* method calls.
246250
*/
251+
@Test
252+
public void testCollectToStringPrefixes() {
253+
Map<String, String> map = new HashMap<>();
254+
map.put("key1", "${parameters.tensor.toString()}");
255+
map.put("key2", "${parameters.output.toString()}");
256+
map.put("key3", "normal value");
257+
258+
List<String> prefixes = StringUtils.collectToStringPrefixes(map);
259+
260+
assertEquals(2, prefixes.size());
261+
assertTrue(prefixes.contains("tensor"));
262+
assertTrue(prefixes.contains("output"));
263+
}
264+
265+
@Test
266+
public void test_GsonTypeAdapters() {
267+
// Test ModelTensor serialization
268+
ModelTensor tensor = ModelTensor
269+
.builder()
270+
.name("test_tensor")
271+
.data(new Number[] { 1, 2, 3 })
272+
.dataType(MLResultDataType.INT32)
273+
.build();
274+
275+
String tensorJson = StringUtils.gson.toJson(tensor);
276+
assertEquals(tensor.toString(), tensorJson);
277+
278+
// Test ModelTensorOutput serialization
279+
List<ModelTensors> outputs = new ArrayList<>();
280+
outputs.add(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build());
281+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(outputs).build();
282+
283+
String outputJson = StringUtils.gson.toJson(output);
284+
assertEquals(output.toString(), outputJson);
285+
286+
// Test ModelTensors serialization
287+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
288+
289+
String tensorsJson = StringUtils.gson.toJson(tensors);
290+
assertEquals(tensors.toString(), tensorsJson);
291+
}
292+
247293
@Test
248294
public void testGetToStringPrefix() {
249295
Map<String, String> parameters = new HashMap<>();
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertThrows;
10+
11+
import java.io.IOException;
12+
import java.io.StringWriter;
13+
14+
import org.junit.Before;
15+
import org.junit.Test;
16+
import org.opensearch.ml.common.output.model.MLResultDataType;
17+
import org.opensearch.ml.common.output.model.ModelTensor;
18+
19+
import com.google.gson.stream.JsonWriter;
20+
21+
public class ToStringTypeAdapterTest {
22+
23+
private ToStringTypeAdapter<ModelTensor> adapter;
24+
private ModelTensor modelTensor;
25+
26+
@Before
27+
public void setUp() {
28+
adapter = new ToStringTypeAdapter<>(ModelTensor.class);
29+
modelTensor = ModelTensor.builder().name("test_tensor").data(new Number[] { 1, 2, 3 }).dataType(MLResultDataType.INT32).build();
30+
}
31+
32+
@Test
33+
public void test_Write_ValidObject() throws IOException {
34+
StringWriter stringWriter = new StringWriter();
35+
JsonWriter jsonWriter = new JsonWriter(stringWriter);
36+
37+
adapter.write(jsonWriter, modelTensor);
38+
39+
String result = stringWriter.toString();
40+
assertEquals(modelTensor.toString(), result);
41+
}
42+
43+
@Test
44+
public void test_Write_NullObject() throws IOException {
45+
StringWriter stringWriter = new StringWriter();
46+
JsonWriter jsonWriter = new JsonWriter(stringWriter);
47+
48+
adapter.write(jsonWriter, null);
49+
50+
String result = stringWriter.toString();
51+
assertEquals("null", result);
52+
}
53+
54+
@Test
55+
public void test_Read_ThrowsUnsupportedOperationException() {
56+
assertThrows(UnsupportedOperationException.class, () -> { adapter.read(null); });
57+
}
58+
}

common/src/test/java/org/opensearch/ml/common/utils/ToolUtilsTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import org.junit.Test;
2121
import org.opensearch.ml.common.agent.MLToolSpec;
22+
import org.opensearch.ml.common.output.model.ModelTensor;
2223

2324
public class ToolUtilsTest {
2425

@@ -343,4 +344,51 @@ public void testExtractInputParameters_NoInputParameter() {
343344
assertEquals("value1", result.get("param1"));
344345
assertEquals("value2", result.get("param2"));
345346
}
347+
348+
@Test
349+
public void testConvertOutputToModelTensor_WithMap() {
350+
Map<String, Object> mapOutput = Map.of("key1", "value1", "key2", "value2");
351+
String outputKey = "test_output";
352+
353+
ModelTensor result = ToolUtils.convertOutputToModelTensor(mapOutput, outputKey);
354+
355+
assertEquals(outputKey, result.getName());
356+
assertEquals(mapOutput, result.getDataAsMap());
357+
}
358+
359+
@Test
360+
public void testConvertOutputToModelTensor_WithList() {
361+
List<String> listOutput = List.of("item1", "item2", "item3");
362+
String outputKey = "test_output";
363+
364+
ModelTensor result = ToolUtils.convertOutputToModelTensor(listOutput, outputKey);
365+
366+
assertEquals(outputKey, result.getName());
367+
Map<String, Object> expectedMap = Map.of("output", listOutput);
368+
assertEquals(expectedMap, result.getDataAsMap());
369+
}
370+
371+
@Test
372+
public void testConvertOutputToModelTensor_WithJsonString() {
373+
String jsonOutput = "{\"key1\":\"value1\",\"key2\":\"value2\"}";
374+
String outputKey = "test_output";
375+
376+
ModelTensor result = ToolUtils.convertOutputToModelTensor(jsonOutput, outputKey);
377+
378+
assertEquals(outputKey, result.getName());
379+
assertTrue(result.getDataAsMap().containsKey("key1"));
380+
assertTrue(result.getDataAsMap().containsKey("key2"));
381+
}
382+
383+
@Test
384+
public void testConvertOutputToModelTensor_WithNonJsonString() {
385+
String stringOutput = "simple string output";
386+
String outputKey = "test_output";
387+
388+
ModelTensor result = ToolUtils.convertOutputToModelTensor(stringOutput, outputKey);
389+
390+
assertEquals(outputKey, result.getName());
391+
Map<String, Object> expectedMap = Map.of("output", stringOutput);
392+
assertEquals(expectedMap, result.getDataAsMap());
393+
}
346394
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ public void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Obje
161161
previousStepListener = nextStepListener;
162162
}
163163
}
164-
firstTool.run(firstToolExecuteParams, firstStepListener);
164+
// firstTool.run(firstToolExecuteParams, firstStepListener);
165+
if (toolSpecs.size() == 1) {
166+
firstTool.run(firstToolExecuteParams, listener);
167+
} else {
168+
firstTool.run(firstToolExecuteParams, firstStepListener);
169+
}
165170
}
166171

167172
@VisibleForTesting

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,4 +491,38 @@ public void testRunWithUpdateFailure() {
491491
assertNotNull(additionalInfo.get(SECOND_TOOL + ".output"));
492492
}
493493

494+
@Test
495+
public void testRunWithReturnDataAsMap() {
496+
final Map<String, String> params = new HashMap<>();
497+
Map<String, Object> toolOutput = Map.of("key1", "value1", "key2", "value2");
498+
MLToolSpec toolSpec = MLToolSpec
499+
.builder()
500+
.name(FIRST_TOOL)
501+
.type(FIRST_TOOL)
502+
.includeOutputInAgentResponse(true)
503+
.parameters(Map.of("return_data_as_map", "true"))
504+
.build();
505+
final MLAgent mlAgent = MLAgent.builder().name("TestAgent").type(MLAgentType.FLOW.name()).tools(Arrays.asList(toolSpec)).build();
506+
507+
doAnswer(invocation -> {
508+
ActionListener<Object> listener = invocation.getArgument(1);
509+
listener.onResponse(toolOutput);
510+
return null;
511+
}).when(firstTool).run(anyMap(), any());
512+
513+
mlFlowAgentRunner.run(mlAgent, params, agentActionListener);
514+
Mockito.verify(agentActionListener).onResponse(objectCaptor.capture());
515+
Object capturedValue = objectCaptor.getValue();
516+
517+
if (capturedValue instanceof List) {
518+
List<ModelTensor> agentOutput = (List<ModelTensor>) capturedValue;
519+
assertEquals(1, agentOutput.size());
520+
assertEquals(FIRST_TOOL + ".output", agentOutput.get(0).getName());
521+
assertEquals(toolOutput, agentOutput.get(0).getDataAsMap());
522+
} else if (capturedValue instanceof Map) {
523+
Map<String, Object> agentOutput = (Map<String, Object>) capturedValue;
524+
assertEquals(toolOutput, agentOutput);
525+
}
526+
}
527+
494528
}

0 commit comments

Comments
 (0)