Skip to content

Commit 79f7649

Browse files
committed
add processor to ListIndexTool
Signed-off-by: Yaliang Wu <[email protected]>
1 parent 32feba7 commit 79f7649

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@
5454
import org.opensearch.core.action.ActionListener;
5555
import org.opensearch.core.action.ActionResponse;
5656
import org.opensearch.index.IndexSettings;
57-
import org.opensearch.ml.common.output.model.ModelTensors;
5857
import org.opensearch.ml.common.spi.tools.Parser;
5958
import org.opensearch.ml.common.spi.tools.Tool;
6059
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
6160
import org.opensearch.ml.common.utils.ToolUtils;
61+
import org.opensearch.ml.engine.tools.parser.ToolParser;
6262
import org.opensearch.transport.client.Client;
6363

6464
import lombok.Getter;
@@ -103,23 +103,14 @@ public class ListIndexTool implements Tool {
103103
@Setter
104104
private Parser<?, ?> inputParser;
105105
@Setter
106-
private Parser<?, ?> outputParser;
106+
private Parser outputParser;
107107
@SuppressWarnings("unused")
108108
private ClusterService clusterService;
109109

110110
public ListIndexTool(Client client, ClusterService clusterService) {
111111
this.client = client;
112112
this.clusterService = clusterService;
113113

114-
outputParser = new Parser<>() {
115-
@Override
116-
public Object parse(Object o) {
117-
@SuppressWarnings("unchecked")
118-
List<ModelTensors> mlModelOutputs = (List<ModelTensors>) o;
119-
return mlModelOutputs.get(0).getMlModelTensors().get(0).getDataAsMap().get("response");
120-
}
121-
};
122-
123114
this.attributes = new HashMap<>();
124115
attributes.put(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA);
125116
attributes.put(STRICT_FIELD, false);
@@ -167,8 +158,12 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
167158
);
168159
}
169160
@SuppressWarnings("unchecked")
170-
T response = (T) sb.toString();
171-
listener.onResponse(response);
161+
T output = (T) sb.toString();
162+
if (outputParser != null) {
163+
listener.onResponse((T) outputParser.parse(output));
164+
} else {
165+
listener.onResponse((T) output);
166+
}
172167
}, listener::onFailure));
173168

174169
fetchClusterInfoAndPages(
@@ -463,8 +458,10 @@ public void init(Client client, ClusterService clusterService) {
463458
}
464459

465460
@Override
466-
public ListIndexTool create(Map<String, Object> map) {
467-
return new ListIndexTool(client, clusterService);
461+
public ListIndexTool create(Map<String, Object> params) {
462+
ListIndexTool tool = new ListIndexTool(client, clusterService);
463+
tool.setOutputParser(ToolParser.createFromToolParams(params));
464+
return tool;
468465
}
469466

470467
@Override

ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/ListIndexToolTests.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,33 @@ public void test_run_successful_2() {
132132
verifyResult(tool, createParameters(null, null, null, null));
133133
}
134134

135+
@Test
136+
public void test_run_with_output_parser() {
137+
mockUp();
138+
Map<String, Object> params = new HashMap<>();
139+
params.put("output_processors", Arrays.asList(Map.of("type", "regex_replace", "pattern", "index-1", "replacement", "test-index")));
140+
Tool tool = ListIndexTool.Factory.getInstance().create(params);
141+
142+
ActionListener<String> listener = mock(ActionListener.class);
143+
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
144+
tool.run(createParameters("[\"index-1\"]", "true", "10", "true"), listener);
145+
verify(listener).onResponse(captor.capture());
146+
assert captor.getValue().contains("test-index");
147+
assert !captor.getValue().contains("index-1");
148+
}
149+
150+
@Test
151+
public void test_run_without_output_parser() {
152+
mockUp();
153+
Tool tool = ListIndexTool.Factory.getInstance().create(Collections.emptyMap());
154+
155+
ActionListener<String> listener = mock(ActionListener.class);
156+
ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
157+
tool.run(createParameters("[\"index-1\"]", "true", "10", "true"), listener);
158+
verify(listener).onResponse(captor.capture());
159+
assert captor.getValue().contains("index-1");
160+
}
161+
135162
private Map<String, String> createParameters(String indices, String local, String pageSize, String includeUnloadedSegments) {
136163
Map<String, String> parameters = new HashMap<>();
137164
if (indices != null) {

0 commit comments

Comments
 (0)