Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ public class ToolMLInput extends MLInput {
@Setter
private String toolName;

@Getter
private Map<String, Object> originalParameters;

public ToolMLInput(StreamInput in) throws IOException {
super(in);
this.toolName = in.readString();
Expand Down Expand Up @@ -66,7 +69,9 @@ public ToolMLInput(XContentParser parser, FunctionName functionName) throws IOEx
toolName = parser.text();
break;
case PARAMETERS_FIELD:
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
Map<String, Object> rawParams = parser.map();
originalParameters = rawParams;
Map<String, String> parameters = StringUtils.getParameterMap(rawParams);
inputDataset = new RemoteInferenceInputDataSet(parameters);
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,10 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Map<String, Object> toolParams = new HashMap<>();
toolParams.putAll(executeParams);
// Parse JSON strings back to original type since we need to validate each parameter type when creating tool
for (Map.Entry<String, String> entry : executeParams.entrySet()) {
toolParams.put(entry.getKey(), parseValue(entry.getValue()));
}
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
if (runtimeResources != null) {
toolParams.putAll(runtimeResources);
Expand All @@ -1014,4 +1017,32 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin

return tool;
}

private static Object parseValue(String value) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the benefit of parsing these values, if a tool requires different types like map or double/float, then they're still being passed to tool in String format.

if (value == null || "null".equals(value)) {
return null;
}
String v = value.trim();

// Try JSON array
if (v.startsWith("[") && v.endsWith("]")) {
try {
return gson.fromJson(v, List.class);
} catch (Exception e) {
return value;
}
}

// Try boolean
if ("true".equalsIgnoreCase(v) || "false".equalsIgnoreCase(v)) {
return Boolean.parseBoolean(v);
}

// Try integer
try {
return Integer.parseInt(v);
} catch (NumberFormatException e) {
return value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ public void execute(Input input, ActionListener<Output> listener) {

try {
Map<String, String> mutableParams = new HashMap<>(parameters);
Tool tool = toolFactory.create(mutableParams);
Map<String, Object> originalParams = toolMLInput.getOriginalParameters();
Tool tool = toolFactory.create(originalParams);

if (!tool.validate(mutableParams)) {
listener.onFailure(new IllegalArgumentException("Invalid parameters for tool: " + toolName));
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.opensearch.action.ActionRequest;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
Expand All @@ -36,6 +37,9 @@
@ToolAnnotation(AgentTool.TYPE)
public class AgentTool implements Tool {
public static final String TYPE = "AgentTool";
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;
private final Client client;

@Setter
Expand Down Expand Up @@ -117,6 +121,15 @@ public void setName(String s) {

@Override
public boolean validate(Map<String, String> parameters) {
if (parameters == null || parameters.isEmpty()) {
return false;
}

// Validate question length
String question = parameters.get("question");
if (question != null && question.length() > maxQuestionLength) {
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
}
return true;
}

Expand Down Expand Up @@ -144,8 +157,11 @@ public void init(Client client) {

@Override
public AgentTool create(Map<String, Object> params) {
ConfigurationUtils.readStringProperty(TYPE, null, params, "question");
AgentTool agentTool = new AgentTool(client, (String) params.get("agent_id"));
agentTool.setOutputParser(ToolParser.createFromToolParams(params));
agentTool.maxQuestionLength = ConfigurationUtils
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
return agentTool;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionRequest;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -125,6 +126,7 @@ public void init(Client client) {

@Override
public ConnectorTool create(Map<String, Object> params) {
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "response_filter");
ConnectorTool connectorTool = new ConnectorTool(client, (String) params.get(CONNECTOR_ID));
connectorTool.setOutputParser(ToolParser.createFromToolParams(params));
return connectorTool;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
Expand Down Expand Up @@ -55,6 +56,9 @@ public class IndexMappingTool implements Tool {
+ "\"required\":[\"index\"],"
+ "\"additionalProperties\":false}";
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, true);
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;

@Setter
@Getter
Expand Down Expand Up @@ -175,7 +179,17 @@ public String getType() {

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && !parameters.isEmpty() && parameters.containsKey("index");
if (parameters == null || parameters.isEmpty() || !parameters.containsKey("index")) {
return false;
}

// Validate question length
String question = parameters.get("question");
if (question != null && question.length() > maxQuestionLength) {
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
}
Comment on lines +187 to +190
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would the index mapping tool, need a "question" attribute?
Looks like we are not using it anywhere

Copy link
Contributor Author

@nathaliellenaa nathaliellenaa Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use it during tool execution. Also, in the doc, it listed question as a required parameter.

POST /_plugins/_ml/agents/9X7xWI0Bpc3sThaJdY9i/_execute
{
  "parameters": {
    "index": [ "sample-ecommerce" ],
    "question": "What fields are in the sample-ecommerce index?"
  }
}

I tested and we can run this tool without the question parameter. I'll update the doc to make it optional.
We can keep the question length validation here since users can still pass the question parameter, but I'll modify this to optional

ConfigurationUtils.readStringProperty(TYPE, null, params, "question");

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although the question is shown in the parameter map passed to tools, but question is an agent level parameter, let's say an agent runs 10 tools, we only need to validate the question once when agent runs, instead of validating it 10 times when each tool runs.


return true;
}

/**
Expand Down Expand Up @@ -212,8 +226,14 @@ public void init(Client client) {

@Override
public IndexMappingTool create(Map<String, Object> params) {
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "question");
ConfigurationUtils.readOptionalList(TYPE, null, params, "index");
ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false);

IndexMappingTool indexMappingTool = new IndexMappingTool(client);
indexMappingTool.setOutputParser(ToolParser.createFromToolParams(params));
indexMappingTool.maxQuestionLength = ConfigurationUtils
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
return indexMappingTool;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.index.IndexSettings;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
Expand Down Expand Up @@ -86,6 +87,9 @@ public class ListIndexTool implements Tool {
+ "for example: [\\\"index1\\\", \\\"index2\\\"], use empty array [] to list all indices in the cluster\"}},"
+ "\"additionalProperties\":false}";
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;
Comment on lines +90 to +92
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, we don't need the question parameter. The tool just ignores the question param

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here, we use it during tool execution. Doc marked it as a required field. But looking at the code, we can actually run this tool without question parameter. I'll update the doc to make this parameter optional.

We can keep the question length validation here since users can still pass the question parameter, but I'll modify this to optional

ConfigurationUtils.readStringProperty(TYPE, null, params, "question");


@Setter
@Getter
Expand Down Expand Up @@ -415,7 +419,16 @@ public void onFailure(final Exception e) {

@Override
public boolean validate(Map<String, String> parameters) {
return parameters != null && !parameters.isEmpty();
if (parameters == null || parameters.isEmpty()) {
return false;
}

// Validate question length
String question = parameters.get("question");
if (question != null && question.length() > maxQuestionLength) {
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
}
return true;
}

/**
Expand Down Expand Up @@ -455,8 +468,15 @@ public void init(Client client, ClusterService clusterService) {

@Override
public ListIndexTool create(Map<String, Object> params) {
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "question");
ConfigurationUtils.readOptionalList(TYPE, null, params, "indices");
ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false);
ConfigurationUtils.readIntProperty(TYPE, null, params, "page_size", 100);

ListIndexTool tool = new ListIndexTool(client, clusterService);
tool.setOutputParser(ToolParser.createFromToolParams(params));
tool.maxQuestionLength = ConfigurationUtils
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
return tool;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import org.opensearch.action.ActionRequest;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
Expand Down Expand Up @@ -172,6 +173,8 @@ public void init(Client client) {

@Override
public MLModelTool create(Map<String, Object> map) {
ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "prompt");
ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "response_field");
String modelId = (String) map.get(MODEL_ID_FIELD);
String responseField = (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
Expand Down Expand Up @@ -121,6 +122,9 @@ public class QueryPlanningTool implements WithModelTool {
+ "}";

public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;

@Getter
@Setter
Expand Down Expand Up @@ -394,6 +398,12 @@ public boolean validate(Map<String, String> parameters) {
|| !parameters.containsKey(INDEX_NAME_FIELD)) {
return false;
}

// Validate question length
String question = parameters.get(QUESTION_FIELD);
if (question != null && question.length() > maxQuestionLength) {
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
}
return true;
}

Expand All @@ -420,6 +430,14 @@ public void init(Client client) {

@Override
public QueryPlanningTool create(Map<String, Object> params) {
ConfigurationUtils.readStringProperty(TYPE, null, params, QUESTION_FIELD);
ConfigurationUtils.readStringProperty(TYPE, null, params, INDEX_NAME_FIELD);
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, GENERATION_TYPE_FIELD);
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, QUERY_PLANNER_SYSTEM_PROMPT_FIELD);
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, QUERY_PLANNER_USER_PROMPT_FIELD);
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "embedding_model_id");
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "response_filter");
ConfigurationUtils.readOptionalList(TYPE, null, params, SEARCH_TEMPLATES_FIELD);

MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params);

Expand Down Expand Up @@ -455,6 +473,8 @@ public QueryPlanningTool create(Map<String, Object> params) {
// Create parser with default extract_json processor + any custom processors
queryPlanningTool.setOutputParser(createParserWithDefaultExtractJson(params));

queryPlanningTool.maxQuestionLength = ConfigurationUtils
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
return queryPlanningTool;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.utils.StringUtils;
Expand Down Expand Up @@ -137,6 +138,7 @@ public void init() {}

@Override
public ReadFromScratchPadTool create(Map<String, Object> params) {
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, PERSISTENT_NOTES_KEY);
return new ReadFromScratchPadTool();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
Expand Down Expand Up @@ -296,6 +297,7 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) {

@Override
public SearchIndexTool create(Map<String, Object> params) {
ConfigurationUtils.readStringProperty(TYPE, null, params, INPUT_FIELD);
SearchIndexTool tool = new SearchIndexTool(client, xContentRegistry);
// Enhance the output parser with processors if configured
tool.setOutputParser(ToolParser.createFromToolParams(params));
Expand Down
Loading
Loading