Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
if (response instanceof String && isJson((String) response)) {
Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
} else if (response instanceof Map) {
modelTensors.add(ModelTensor.builder().name("response").dataAsMap((Map<String, ?>) response).build());
} else {
Map<String, Object> map = new HashMap<>();
map.put("response", response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.List;
import java.util.Map;

import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down Expand Up @@ -289,4 +290,14 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
}

@Override
public String toString() {
try {
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
} catch (IOException e) {
throw new IllegalArgumentException("Can't convert ModelTensor to string", e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.ArrayList;
import java.util.List;

import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand Down Expand Up @@ -102,4 +103,13 @@ public static ModelTensorOutput parse(XContentParser parser) throws IOException

return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build();
}

@Override
public String toString() {
try {
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.List;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -171,4 +172,13 @@ public static ModelTensors parse(XContentParser parser) throws IOException {
modelTensors.setStatusCode(statusCode);
return modelTensors;
}

@Override
public String toString() {
try {
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
} catch (IOException e) {
throw new IllegalArgumentException("Can't convert ModelTensor to string", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
import org.json.JSONObject;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -77,7 +80,6 @@ public class StringUtils {

public static final String SAFE_INPUT_DESCRIPTION = "can only contain letters, numbers, spaces, and basic punctuation (.,!?():@-_'/\")";

public static final Gson gson = new Gson();
public static final Gson PLAIN_NUMBER_GSON = new GsonBuilder()
.serializeNulls()
.registerTypeAdapter(Float.class, new PlainFloatAdapter())
Expand All @@ -86,6 +88,15 @@ public class StringUtils {
.registerTypeAdapter(double.class, new PlainDoubleAdapter())
.create();

public static final Gson gson;
static {
gson = new GsonBuilder()
.disableHtmlEscaping()
.registerTypeAdapter(ModelTensor.class, new ToStringTypeAdapter<>(ModelTensor.class))
.registerTypeAdapter(ModelTensorOutput.class, new ToStringTypeAdapter<>(ModelTensorOutput.class))
.registerTypeAdapter(ModelTensors.class, new ToStringTypeAdapter<>(ModelTensors.class))
.create();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static final ObjectMapper MAPPER = new ObjectMapper();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.utils;

import java.io.IOException;

import com.google.gson.TypeAdapter;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;

public class ToStringTypeAdapter<T> extends TypeAdapter<T> {

private final Class<T> clazz;

public ToStringTypeAdapter(Class<T> clazz) {
this.clazz = clazz;
}

@Override
public void write(JsonWriter out, T value) throws IOException {
if (value == null) {
out.nullValue();
return;
}
String json = value.toString();
out.jsonValue(json);
}

@Override
public T read(JsonReader in) throws IOException {
throw new UnsupportedOperationException("Deserialization not supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,38 @@ public static String getToolName(MLToolSpec toolSpec) {
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
}

/**
* Converts various types of tool output into a standardized ModelTensor format.
* The conversion logic depends on the type of input:
* <ul>
* <li>For Map inputs: directly uses the map as data</li>
* <li>For List inputs: wraps the list in a map with "output" as the key</li>
* <li>For other types: converts to JSON string and attempts to parse as map,
* if parsing fails, wraps the original output in a map with "output" as the key</li>
* </ul>
*
* @param output The output object to be converted. Can be a Map, List, or any other object
* @param outputKey The key/name to be assigned to the resulting ModelTensor
* @return A ModelTensor containing the formatted output data
*/
public static ModelTensor convertOutputToModelTensor(Object output, String outputKey) {
ModelTensor modelTensor;
if (output instanceof Map) {
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap((Map) output).build();
} else if (output instanceof List) {
Map<String, Object> resultMap = new HashMap<>();
resultMap.put("output", output);
Comment on lines +235 to +238
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we put the output in output field here, but above we just return as is if it is a map, so output key may or may not be there? is this expected

modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
} else {
String outputJson = StringUtils.toJson(output);
Map<String, Object> resultMap;
if (StringUtils.isJson(outputJson)) {
resultMap = StringUtils.fromJson(outputJson, "output");
} else {
resultMap = Map.of("output", output);
}
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
}
return modelTensor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,20 @@ public void parseResponse_NonJsonString() throws IOException {
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
}

@Test
public void parseResponse_MapResponse() throws IOException {
HttpConnector connector = createHttpConnector();
Map<String, Object> responseMap = new HashMap<>();
responseMap.put("key1", "value1");
responseMap.put("key2", "value2");
List<ModelTensor> modelTensors = new ArrayList<>();

connector.parseResponse(responseMap, modelTensors, false);
Assert.assertEquals(1, modelTensors.size());
Assert.assertEquals("response", modelTensors.get(0).getName());
Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap());
}

@Test
public void fillNullParameters() {
HttpConnector connector = createHttpConnector();
Expand Down Expand Up @@ -442,4 +456,33 @@ public void parse_WithTenantId() throws IOException {
Assert.assertEquals("test_tenant", connector.getTenantId());
}

@Test
public void testParseResponse_MapResponse() throws IOException {
HttpConnector connector = createHttpConnector();

Map<String, Object> responseMap = new HashMap<>();
responseMap.put("result", "success");
responseMap.put("data", Arrays.asList("item1", "item2"));

List<ModelTensor> modelTensors = new ArrayList<>();
connector.parseResponse(responseMap, modelTensors, false);

Assert.assertEquals(1, modelTensors.size());
Assert.assertEquals("response", modelTensors.get(0).getName());
Assert.assertEquals(responseMap, modelTensors.get(0).getDataAsMap());
}

@Test
public void testParseResponse_NonStringNonMapResponse() throws IOException {
HttpConnector connector = createHttpConnector();

Integer numericResponse = 42;
List<ModelTensor> modelTensors = new ArrayList<>();
connector.parseResponse(numericResponse, modelTensors, false);

Assert.assertEquals(1, modelTensors.size());
Assert.assertEquals("response", modelTensors.get(0).getName());
Assert.assertEquals(42, modelTensors.get(0).getDataAsMap().get("response"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ public void parse_SkipIrrelevantFields() throws IOException {
assertArrayEquals(new long[] { 1, 3 }, modelTensor.getShape());
}

@Test
public void test_ToString() {
String result = modelTensorOutput.toString();
String expected =
"{\"inference_results\":[{\"output\":[{\"name\":\"test\",\"data_type\":\"FLOAT32\",\"shape\":[1,3],\"data\":[1.0,2.0,3.0],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}]}";
assertEquals(expected, result);
}

private void readInputStream(ModelTensorOutput input, Consumer<ModelTensorOutput> verify) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
input.writeTo(bytesStreamOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
package org.opensearch.ml.common.output.model;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.IOException;
Expand Down Expand Up @@ -122,4 +125,23 @@ public void test_NullDataType() {
.byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 }))
.build();
}

@Test
public void test_ToString() {
String result = modelTensor.toString();
String expected =
"{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"},\"result\":\"test result\",\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}";
assertEquals(expected, result);
}

@Test
public void test_ToString_ThrowsException() throws IOException {
ModelTensor spyTensor = spy(modelTensor);
doThrow(new IOException("Mock IOException")).when(spyTensor).toXContent(any(XContentBuilder.class), any());

exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Can't convert ModelTensor to string");

spyTensor.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

import java.io.IOException;
Expand Down Expand Up @@ -274,4 +277,23 @@ public void parse_SkipIrrelevantFields() throws IOException {
ModelTensor modelTensor = parsedTensors.getMlModelTensors().get(0);
assertEquals("test_tensor", modelTensor.getName());
}

@Test
public void test_ToString() {
String result = modelTensors.toString();
String expected =
"{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}";
assertEquals(expected, result);
}

@Test
public void test_ToString_ThrowsException() throws IOException {
ModelTensors spyTensors = spy(modelTensors);
doThrow(new IOException("Mock IOException")).when(spyTensors).toXContent(any(XContentBuilder.class), any());

exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Can't convert ModelTensor to string");

spyTensors.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
import org.junit.Test;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;

import com.google.gson.JsonElement;
import com.google.gson.TypeAdapter;
Expand Down Expand Up @@ -244,6 +248,48 @@ public void testGetErrorMessageWhenHiddenNull() {
* in the values. Verifies that the method correctly extracts the prefixes of the toString()
* method calls.
*/
@Test
public void testCollectToStringPrefixes() {
Map<String, String> map = new HashMap<>();
map.put("key1", "${parameters.tensor.toString()}");
map.put("key2", "${parameters.output.toString()}");
map.put("key3", "normal value");

List<String> prefixes = StringUtils.collectToStringPrefixes(map);

assertEquals(2, prefixes.size());
assertTrue(prefixes.contains("tensor"));
assertTrue(prefixes.contains("output"));
}

@Test
public void test_GsonTypeAdapters() {
// Test ModelTensor serialization
ModelTensor tensor = ModelTensor
.builder()
.name("test_tensor")
.data(new Number[] { 1, 2, 3 })
.dataType(MLResultDataType.INT32)
.build();

String tensorJson = StringUtils.gson.toJson(tensor);
assertEquals(tensor.toString(), tensorJson);

// Test ModelTensorOutput serialization
List<ModelTensors> outputs = new ArrayList<>();
outputs.add(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build());
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(outputs).build();

String outputJson = StringUtils.gson.toJson(output);
assertEquals(output.toString(), outputJson);

// Test ModelTensors serialization
ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();

String tensorsJson = StringUtils.gson.toJson(tensors);
assertEquals(tensors.toString(), tensorsJson);
}

@Test
public void testGetToStringPrefix() {
Map<String, String> parameters = new HashMap<>();
Expand Down
Loading
Loading