|
54 | 54 | import org.opensearch.core.action.ActionListener;
|
55 | 55 | import org.opensearch.core.action.ActionResponse;
|
56 | 56 | import org.opensearch.index.IndexSettings;
|
57 |
| -import org.opensearch.ml.common.output.model.ModelTensors; |
58 | 57 | import org.opensearch.ml.common.spi.tools.Parser;
|
59 | 58 | import org.opensearch.ml.common.spi.tools.Tool;
|
60 | 59 | import org.opensearch.ml.common.spi.tools.ToolAnnotation;
|
61 | 60 | import org.opensearch.ml.common.utils.ToolUtils;
|
| 61 | +import org.opensearch.ml.engine.tools.parser.ToolParser; |
62 | 62 | import org.opensearch.transport.client.Client;
|
63 | 63 |
|
64 | 64 | import lombok.Getter;
|
@@ -103,23 +103,14 @@ public class ListIndexTool implements Tool {
|
103 | 103 | @Setter
|
104 | 104 | private Parser<?, ?> inputParser;
|
105 | 105 | @Setter
|
106 |
| - private Parser<?, ?> outputParser; |
| 106 | + private Parser outputParser; |
107 | 107 | @SuppressWarnings("unused")
|
108 | 108 | private ClusterService clusterService;
|
109 | 109 |
|
110 | 110 | public ListIndexTool(Client client, ClusterService clusterService) {
|
111 | 111 | this.client = client;
|
112 | 112 | this.clusterService = clusterService;
|
113 | 113 |
|
114 |
| - outputParser = new Parser<>() { |
115 |
| - @Override |
116 |
| - public Object parse(Object o) { |
117 |
| - @SuppressWarnings("unchecked") |
118 |
| - List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o; |
119 |
| - return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response"); |
120 |
| - } |
121 |
| - }; |
122 |
| - |
123 | 114 | this.attributes = new HashMap<>();
|
124 | 115 | attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA);
|
125 | 116 | attributes.put(STRICT_FIELD, false);
|
@@ -167,8 +158,12 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
|
167 | 158 | );
|
168 | 159 | }
|
169 | 160 | @SuppressWarnings("unchecked")
|
170 |
| - T response = (T) sb.toString(); |
171 |
| - listener.onResponse(response); |
| 161 | + T output = (T) sb.toString(); |
| 162 | + if (outputParser != null) { |
| 163 | + listener.onResponse((T) outputParser.parse(output)); |
| 164 | + } else { |
| 165 | + listener.onResponse((T) output); |
| 166 | + } |
172 | 167 | }, listener::onFailure));
|
173 | 168 |
|
174 | 169 | fetchClusterInfoAndPages(
|
@@ -463,8 +458,10 @@ public void init(Client client, ClusterService clusterService) {
|
463 | 458 | }
|
464 | 459 |
|
465 | 460 | @Override
|
466 |
| - public ListIndexTool create(Map<String, Object> map) { |
467 |
| - return new ListIndexTool(client, clusterService); |
| 461 | + public ListIndexTool create(Map<String, Object> params) { |
| 462 | + ListIndexTool tool = new ListIndexTool(client, clusterService); |
| 463 | + tool.setOutputParser(ToolParser.createFromToolParams(params)); |
| 464 | + return tool; |
468 | 465 | }
|
469 | 466 |
|
470 | 467 | @Override
|
|
0 commit comments