Skip to content

Commit db22e6f

Browse files
committed
add processor chain and add support for model and tool
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 878bbcc commit db22e6f

29 files changed

+3359
-205
lines changed

common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ public <T> void parseResponse(T response, List<ModelTensor> modelTensors, boolea
111111
if (response instanceof String && isJson((String) response)) {
112112
Map<String, Object> data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY);
113113
modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build());
114+
} else if (response instanceof Map) {
115+
modelTensors.add(ModelTensor.builder().name("response").dataAsMap((Map<String, ?>) response).build());
114116
} else {
115117
Map<String, Object> map = new HashMap<>();
116118
map.put("response", response);

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919
import java.util.Map;
2020

21+
import org.opensearch.common.xcontent.json.JsonXContent;
2122
import org.opensearch.core.common.io.stream.StreamInput;
2223
import org.opensearch.core.common.io.stream.StreamOutput;
2324
import org.opensearch.core.common.io.stream.Writeable;
@@ -289,4 +290,14 @@ public void writeTo(StreamOutput out) throws IOException {
289290
out.writeBoolean(false);
290291
}
291292
}
293+
294+
@Override
295+
public String toString() {
296+
try {
297+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
298+
} catch (IOException e) {
299+
throw new RuntimeException(e);
300+
}
301+
}
302+
292303
}

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.ArrayList;
1212
import java.util.List;
1313

14+
import org.opensearch.common.xcontent.json.JsonXContent;
1415
import org.opensearch.core.common.io.stream.StreamInput;
1516
import org.opensearch.core.common.io.stream.StreamOutput;
1617
import org.opensearch.core.xcontent.XContentBuilder;
@@ -102,4 +103,13 @@ public static ModelTensorOutput parse(XContentParser parser) throws IOException
102103

103104
return ModelTensorOutput.builder().mlModelOutputs(mlModelOutputs).build();
104105
}
106+
107+
@Override
108+
public String toString() {
109+
try {
110+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
111+
} catch (IOException e) {
112+
throw new RuntimeException(e);
113+
}
114+
}
105115
}

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414

1515
import org.opensearch.common.io.stream.BytesStreamOutput;
16+
import org.opensearch.common.xcontent.json.JsonXContent;
1617
import org.opensearch.core.common.bytes.BytesReference;
1718
import org.opensearch.core.common.io.stream.StreamInput;
1819
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -171,4 +172,13 @@ public static ModelTensors parse(XContentParser parser) throws IOException {
171172
modelTensors.setStatusCode(statusCode);
172173
return modelTensors;
173174
}
175+
176+
@Override
177+
public String toString() {
178+
try {
179+
return this.toXContent(JsonXContent.contentBuilder(), null).toString();
180+
} catch (IOException e) {
181+
throw new RuntimeException(e);
182+
}
183+
}
174184
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@
3333
import org.json.JSONObject;
3434
import org.opensearch.OpenSearchParseException;
3535
import org.opensearch.action.ActionRequestValidationException;
36+
import org.opensearch.ml.common.output.model.ModelTensor;
37+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
38+
import org.opensearch.ml.common.output.model.ModelTensors;
3639

3740
import com.fasterxml.jackson.core.JsonProcessingException;
3841
import com.fasterxml.jackson.databind.JsonNode;
3942
import com.fasterxml.jackson.databind.ObjectMapper;
4043
import com.google.gson.Gson;
44+
import com.google.gson.GsonBuilder;
4145
import com.google.gson.JsonElement;
4246
import com.google.gson.JsonObject;
4347
import com.google.gson.JsonParser;
@@ -74,7 +78,12 @@ public class StringUtils {
7478
public static final Gson gson;
7579

7680
static {
77-
gson = new Gson();
81+
gson = new GsonBuilder()
82+
.disableHtmlEscaping()
83+
.registerTypeAdapter(ModelTensor.class, new ToStringTypeAdapter<>(ModelTensor.class))
84+
.registerTypeAdapter(ModelTensorOutput.class, new ToStringTypeAdapter<>(ModelTensorOutput.class))
85+
.registerTypeAdapter(ModelTensors.class, new ToStringTypeAdapter<>(ModelTensors.class))
86+
.create();
7887
}
7988
public static final String TO_STRING_FUNCTION_NAME = ".toString()";
8089

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.utils;
7+
8+
import java.io.IOException;
9+
10+
import com.google.gson.TypeAdapter;
11+
import com.google.gson.stream.JsonReader;
12+
import com.google.gson.stream.JsonWriter;
13+
14+
public class ToStringTypeAdapter<T> extends TypeAdapter<T> {
15+
16+
private final Class<T> clazz;
17+
18+
public ToStringTypeAdapter(Class<T> clazz) {
19+
this.clazz = clazz;
20+
}
21+
22+
@Override
23+
public void write(JsonWriter out, T value) throws IOException {
24+
if (value == null) {
25+
out.nullValue();
26+
return;
27+
}
28+
String json = value.toString();
29+
out.jsonValue(json);
30+
}
31+
32+
@Override
33+
public T read(JsonReader in) throws IOException {
34+
throw new UnsupportedOperationException("Deserialization not supported");
35+
}
36+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ToolUtils.java renamed to common/src/main/java/org/opensearch/ml/common/utils/ToolUtils.java

Lines changed: 114 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.tools;
6+
package org.opensearch.ml.common.utils;
77

88
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
99
import static org.opensearch.ml.common.utils.StringUtils.gson;
1010

11-
import java.util.ArrayList;
1211
import java.util.HashMap;
1312
import java.util.List;
1413
import java.util.Map;
@@ -19,8 +18,6 @@
1918
import org.opensearch.ml.common.output.model.ModelTensor;
2019
import org.opensearch.ml.common.output.model.ModelTensorOutput;
2120
import org.opensearch.ml.common.output.model.ModelTensors;
22-
import org.opensearch.ml.common.spi.tools.Tool;
23-
import org.opensearch.ml.common.utils.StringUtils;
2421

2522
import com.google.gson.reflect.TypeToken;
2623
import com.jayway.jsonpath.JsonPath;
@@ -38,6 +35,19 @@ public class ToolUtils {
3835
public static final String TOOL_OUTPUT_FILTERS_FIELD = "output_filter";
3936
public static final String TOOL_REQUIRED_PARAMS = "required_parameters";
4037

38+
/**
39+
* Extracts required parameters based on tool attributes specification.
40+
* <p>
41+
* The method performs the following:
42+
* <ul>
43+
* <li>If required parameters are specified in attributes, only those parameters are extracted</li>
44+
* <li>If no required parameters are specified, all parameters are returned</li>
45+
* </ul>
46+
*
47+
* @param parameters The input parameters map to extract from
48+
* @param attributes The attributes map containing required parameter specifications
49+
* @return Map containing only the required parameters
50+
*/
4151
public static Map<String, String> extractRequiredParameters(Map<String, String> parameters, Map<String, ?> attributes) {
4252
Map<String, String> extractedParameters = new HashMap<>();
4353
if (parameters == null) {
@@ -56,6 +66,26 @@ public static Map<String, String> extractRequiredParameters(Map<String, String>
5666
return extractedParameters;
5767
}
5868

69+
/**
70+
* Extracts and processes input parameters, including handling "input" parameter.
71+
* <p>
72+
* The method performs the following steps:
73+
* <ol>
74+
* <li>Extracts required parameters based on tool attributes specification</li>
75+
* <li>If an "input" parameter exists:
76+
* <ul>
77+
* <li>Substitutes any parameter placeholders</li>
78+
* <li>Parses it as a JSON map</li>
79+
* <li>Merges the parsed values with other parameters</li>
80+
* </ul>
81+
* </li>
82+
* </ol>
83+
*
84+
* @param parameters The raw input parameters
85+
* @param attributes The tool attributes containing parameter specifications
86+
* @return Map of processed input parameters
87+
* @throws IllegalArgumentException if input JSON parsing fails
88+
*/
5989
public static Map<String, String> extractInputParameters(Map<String, String> parameters, Map<String, ?> attributes) {
6090
Map<String, String> extractedParameters = ToolUtils.extractRequiredParameters(parameters, attributes);
6191
if (extractedParameters.containsKey("input")) {
@@ -73,6 +103,22 @@ public static Map<String, String> extractInputParameters(Map<String, String> par
73103
return extractedParameters;
74104
}
75105

106+
/**
107+
* Builds the final parameter map for tool execution.
108+
* <p>
109+
* The method performs the following steps:
110+
* <ol>
111+
* <li>Combines tool specification parameters with input parameters</li>
112+
* <li>Processes tool-specific parameter prefixes</li>
113+
* <li>Applies configuration overrides from tool specification</li>
114+
* <li>Adds tenant identification</li>
115+
* </ol>
116+
*
117+
* @param parameters The input parameters to process
118+
* @param toolSpec The tool specification containing default parameters and configuration
119+
* @param tenantId The identifier for the tenant
120+
* @return Map of processed parameters ready for tool execution
121+
*/
76122
public static Map<String, String> buildToolParameters(Map<String, String> parameters, MLToolSpec toolSpec, String tenantId) {
77123
Map<String, String> executeParams = new HashMap<>();
78124
if (toolSpec.getParameters() != null) {
@@ -102,30 +148,14 @@ public static Map<String, String> buildToolParameters(Map<String, String> parame
102148
return executeParams;
103149
}
104150

105-
public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<String, String> executeParams, MLToolSpec toolSpec) {
106-
if (!toolFactories.containsKey(toolSpec.getType())) {
107-
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
108-
}
109-
Map<String, Object> toolParams = new HashMap<>();
110-
toolParams.putAll(executeParams);
111-
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
112-
if (runtimeResources != null) {
113-
toolParams.putAll(runtimeResources);
114-
}
115-
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
116-
String toolName = getToolName(toolSpec);
117-
tool.setName(toolName);
118-
119-
if (toolSpec.getDescription() != null) {
120-
tool.setDescription(toolSpec.getDescription());
121-
}
122-
if (executeParams.containsKey(toolName + ".description")) {
123-
tool.setDescription(executeParams.get(toolName + ".description"));
124-
}
125-
126-
return tool;
127-
}
128-
151+
/**
152+
* Filters tool output based on specified output filters in tool parameters.
153+
* Uses JSONPath expressions to extract specific portions of the response.
154+
*
155+
* @param toolParams The tool parameters containing output filter specifications
156+
* @param response The raw tool response to filter
157+
* @return Filtered output if successful, original response if filtering fails
158+
*/
129159
public static Object filterToolOutput(Map<String, String> toolParams, Object response) {
130160
if (toolParams != null && toolParams.containsKey(TOOL_OUTPUT_FILTERS_FIELD)) {
131161
try {
@@ -142,6 +172,20 @@ public static Object filterToolOutput(Map<String, String> toolParams, Object res
142172
return response;
143173
}
144174

175+
/**
176+
* Parses different types of tool responses into a JSON string representation.
177+
* <p>
178+
* Handles the following special cases:
179+
* <ul>
180+
* <li>ModelTensors - converts to XContent JSON representation</li>
181+
* <li>ModelTensor - converts to XContent JSON representation</li>
182+
* <li>ModelTensorOutput - converts to XContent JSON representation</li>
183+
* <li>Other types - converts to generic JSON string</li>
184+
* </ul>
185+
*
186+
* @param output The tool output object to parse
187+
* @return JSON string representation of the output
188+
*/
145189
public static String parseResponse(Object output) {
146190
try {
147191
if (output instanceof List && !((List) output).isEmpty() && ((List) output).get(0) instanceof ModelTensors) {
@@ -159,16 +203,49 @@ public static String parseResponse(Object output) {
159203
}
160204
}
161205

162-
public static List<String> getToolNames(Map<String, Tool> tools) {
163-
final List<String> inputTools = new ArrayList<>();
164-
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
165-
String toolName = entry.getValue().getName();
166-
inputTools.add(toolName);
167-
}
168-
return inputTools;
169-
}
170-
206+
/**
207+
* Gets the tool name from a tool specification.
208+
* Returns the specified name if available, otherwise returns the tool type.
209+
*
210+
* @param toolSpec The tool specification
211+
* @return The name of the tool
212+
*/
171213
public static String getToolName(MLToolSpec toolSpec) {
172214
return toolSpec.getName() != null ? toolSpec.getName() : toolSpec.getType();
173215
}
216+
217+
/**
218+
* Converts various types of tool output into a standardized ModelTensor format.
219+
* The conversion logic depends on the type of input:
220+
* <ul>
221+
* <li>For Map inputs: directly uses the map as data</li>
222+
* <li>For List inputs: wraps the list in a map with "output" as the key</li>
223+
* <li>For other types: converts to JSON string and attempts to parse as map,
224+
* if parsing fails, wraps the original output in a map with "output" as the key</li>
225+
* </ul>
226+
*
227+
* @param output The output object to be converted. Can be a Map, List, or any other object
228+
* @param outputKey The key/name to be assigned to the resulting ModelTensor
229+
* @return A ModelTensor containing the formatted output data
230+
*/
231+
public static ModelTensor convertOutputToModelTensor(Object output, String outputKey) {
232+
ModelTensor modelTensor;
233+
if (output instanceof Map) {
234+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap((Map) output).build();
235+
} else if (output instanceof List) {
236+
Map<String, Object> resultMap = new HashMap<>();
237+
resultMap.put("output", output);
238+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
239+
} else {
240+
String outputJson = StringUtils.toJson(output);
241+
Map<String, Object> resultMap;
242+
if (StringUtils.isJson(outputJson)) {
243+
resultMap = StringUtils.fromJson(outputJson, "output");
244+
} else {
245+
resultMap = Map.of("output", output);
246+
}
247+
modelTensor = ModelTensor.builder().name(outputKey).dataAsMap(resultMap).build();
248+
}
249+
return modelTensor;
250+
}
174251
}

0 commit comments

Comments
 (0)