Skip to content

Commit 13e2f43

Browse files
committed
add processor chain and add support for model and tool
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 55c1e90 commit 13e2f43

File tree

18 files changed

+2981
-26
lines changed

18 files changed

+2981
-26
lines changed

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
111111
if (response instanceof String && isJson((String) response)) {
112112
Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
113113
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
114+
} else if (response instanceof Map) {
115+
modelTensors.add(ModelTensor.builder().name("response").dataAsMap((Map<String, ?>) response).build());
114116
} else {
115117
Map<String, Object> map = new HashMap<>();
116118
map.put("response", response);

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020

21+
import org.opensearch.common.xcontent.json.JsonXContent;
2122
import org.opensearch.core.common.io.stream.StreamInput;
2223
import org.opensearch.core.common.io.stream.StreamOutput;
2324
import org.opensearch.core.common.io.stream.Writeable;
@@ -289,4 +290,14 @@ public void writeTo(StreamOutput out) throws IOException {
289290
out.writeBoolean(false);
290291
}
291292
}
293+
294+
@Override
295+
public String toString() {
296+
try {
297+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
298+
} catch (IOException e) {
299+
throw new RuntimeException(e);
300+
}
301+
}
302+
292303
}

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.ArrayList;
1212
import java.util.List;
1313

14+
import org.opensearch.common.xcontent.json.JsonXContent;
1415
import org.opensearch.core.common.io.stream.StreamInput;
1516
import org.opensearch.core.common.io.stream.StreamOutput;
1617
import org.opensearch.core.xcontent.XContentBuilder;
@@ -102,4 +103,13 @@ public static ModelTensorOutput parse(XContentParser parser) throws IOException
102103

103104
return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build();
104105
}
106+
107+
@Override
108+
public String toString() {
109+
try {
110+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
111+
} catch (IOException e) {
112+
throw new RuntimeException(e);
113+
}
114+
}
105115
}

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414

1515
import org.opensearch.common.io.stream.BytesStreamOutput;
16+
import org.opensearch.common.xcontent.json.JsonXContent;
1617
import org.opensearch.core.common.bytes.BytesReference;
1718
import org.opensearch.core.common.io.stream.StreamInput;
1819
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -171,4 +172,13 @@ public static ModelTensors parse(XContentParser parser) throws IOException {
171172
modelTensors.setStatusCode(statusCode);
172173
return modelTensors;
173174
}
175+
176+
@Override
177+
public String toString() {
178+
try {
179+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
180+
} catch (IOException e) {
181+
throw new RuntimeException(e);
182+
}
183+
}
174184
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
import org.json.JSONObject;
3636
import org.opensearch.OpenSearchParseException;
3737
import org.opensearch.action.ActionRequestValidationException;
38+
import org.opensearch.ml.common.output.model.ModelTensor;
39+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
40+
import org.opensearch.ml.common.output.model.ModelTensors;
3841

3942
import com.fasterxml.jackson.core.JsonProcessingException;
4043
import com.fasterxml.jackson.databind.JsonNode;
@@ -77,7 +80,6 @@ public class StringUtils {
7780

7881
public static final String SAFE_INPUT_DESCRIPTION = "can only contain letters, numbers, spaces, and basic punctuation (.,!?():@-_'/\")";
7982

80-
public static final Gson gson = new Gson();
8183
public static final Gson PLAIN_NUMBER_GSON = new GsonBuilder()
8284
.serializeNulls()
8385
.registerTypeAdapter(Float.class, new PlainFloatAdapter())
@@ -86,6 +88,15 @@ public class StringUtils {
8688
.registerTypeAdapter(double.class, new PlainDoubleAdapter())
8789
.create();
8890

91+
public static final Gson gson;
92+
static {
93+
gson = new GsonBuilder()
94+
.disableHtmlEscaping()
95+
.registerTypeAdapter(ModelTensor.class, new ToStringTypeAdapter<>(ModelTensor.class))
96+
.registerTypeAdapter(ModelTensorOutput.class, new ToStringTypeAdapter<>(ModelTensorOutput.class))
97+
.registerTypeAdapter(ModelTensors.class, new ToStringTypeAdapter<>(ModelTensors.class))
98+
.create();
99+
}
89100
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
90101

91102
public static final ObjectMapper MAPPER = new ObjectMapper();
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import java.io.IOException;
9+
10+
import com.google.gson.TypeAdapter;
11+
import com.google.gson.stream.JsonReader;
12+
import com.google.gson.stream.JsonWriter;
13+
14+
public class ToStringTypeAdapter<T> extends TypeAdapter<T> {
15+
16+
private final Class<T> clazz;
17+
18+
public ToStringTypeAdapter(Class<T> clazz) {
19+
this.clazz = clazz;
20+
}
21+
22+
@Override
23+
public void write(JsonWriter out, T value) throws IOException {
24+
if (value == null) {
25+
out.nullValue();
26+
return;
27+
}
28+
String json = value.toString();
29+
out.jsonValue(json);
30+
}
31+
32+
@Override
33+
public T read(JsonReader in) throws IOException {
34+
throw new UnsupportedOperationException("Deserialization not supported");
35+
}
36+
}

common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,38 @@ public static String getToolName(MLToolSpec toolSpec) {
215215
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
216216
}
217217

218+
/**
219+
* Converts various types of tool output into a standardized ModelTensor format.
220+
* The conversion logic depends on the type of input:
221+
* <ul>
222+
* <li>For Map inputs: directly uses the map as data</li>
223+
* <li>For List inputs: wraps the list in a map with "output" as the key</li>
224+
* <li>For other types: converts to JSON string and attempts to parse as map,
225+
* if parsing fails, wraps the original output in a map with "output" as the key</li>
226+
* </ul>
227+
*
228+
* @param output The output object to be converted. Can be a Map, List, or any other object
229+
* @param outputKey The key/name to be assigned to the resulting ModelTensor
230+
* @return A ModelTensor containing the formatted output data
231+
*/
232+
public static ModelTensor convertOutputToModelTensor(Object output, String outputKey) {
233+
ModelTensor modelTensor;
234+
if (output instanceof Map) {
235+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap((Map) output).build();
236+
} else if (output instanceof List) {
237+
Map<String, Object> resultMap = new HashMap<>();
238+
resultMap.put("output", output);
239+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
240+
} else {
241+
String outputJson = StringUtils.toJson(output);
242+
Map<String, Object> resultMap;
243+
if (StringUtils.isJson(outputJson)) {
244+
resultMap = StringUtils.fromJson(outputJson, "output");
245+
} else {
246+
resultMap = Map.of("output", output);
247+
}
248+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
249+
}
250+
return modelTensor;
251+
}
218252
}

0 commit comments

Comments
 (0)