From 04c705daa37cdbeffa976762c9364a9eeacdb731 Mon Sep 17 00:00:00 2001 From: Joshua Palis Date: Wed, 17 Sep 2025 20:56:39 +0000 Subject: [PATCH] Adding query planning tool search template validation and integration tests Signed-off-by: Joshua Palis --- .../tools/QueryPlanningPromptTemplate.java | 8 +- .../ml/engine/tools/QueryPlanningTool.java | 47 +++++++-- .../engine/tools/QueryPlanningToolTests.java | 32 ++++++ .../ml/rest/RestQueryPlanningToolIT.java | 97 +++++++++++++++++++ 4 files changed, 172 insertions(+), 12 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java index c71422d40b..fc31c9eba6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningPromptTemplate.java @@ -201,14 +201,14 @@ public class QueryPlanningPromptTemplate { + "- If no perfect match exists, pick the closest by the criteria above. Never output “none” or invent an id."; public static final String TEMPLATE_SELECTION_INPUTS = "question: ${parameters.query_text}\n" - + "templates: ${parameters.search_templates}"; + + "search_templates: ${parameters.search_templates}"; public static final String TEMPLATE_SELECTION_EXAMPLES = "Example A: \n" + "question: 'what shoes are highly rated'\n" - + "templates:\n" + + "search_templates :\n" + "[\n" - + "{'id':'product-search-template','description':'Searches products in an e-commerce store.'},\n" - + "{'id':'sales-value-analysis-template','description':'Aggregates sales value for top-selling products.'}\n" + + "{'template_id':'product-search-template','template_description':'Searches products in an e-commerce store.'},\n" + + "{'template_id':'sales-value-analysis-template','template_description':'Aggregates sales value for top-selling products.'}\n" + "]\n" + "Example output : 'product-search-template'"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java index 1549dd515b..8e0e6f3e96 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java @@ -27,6 +27,8 @@ import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.transport.client.Client; +import com.google.gson.reflect.TypeToken; + import lombok.Getter; import lombok.Setter; @@ -46,13 +48,16 @@ public class QueryPlanningTool implements WithModelTool { public static final String USER_PROMPT_FIELD = "user_prompt"; public static final String INDEX_MAPPING_FIELD = "index_mapping"; public static final String QUERY_FIELDS_FIELD = "query_fields"; - private static final String GENERATION_TYPE_FIELD = "generation_type"; + public static final String GENERATION_TYPE_FIELD = "generation_type"; private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated"; - private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates"; - private static final String SEARCH_TEMPLATES_FIELD = "search_templates"; + public static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates"; + public static final String SEARCH_TEMPLATES_FIELD = "search_templates"; public static final String TEMPLATE_FIELD = "template"; + private static final String TEMPLATE_ID_FIELD = "template_id"; + private static final String TEMPLATE_DESCRIPTION_FIELD = "template_description"; private static final String DEFAULT_SYSTEM_PROMPT = "You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries"; + @Getter private final String generationType; @Getter @@ -102,17 +107,19 @@ public void run(Map originalParameters, ActionListener li templateSelectionParameters.put(SEARCH_TEMPLATES_FIELD, searchTemplates); ActionListener templateSelectionListener = ActionListener.wrap(r -> { + // Default search template if LLM does not choose or if returned search template is null + parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE); try { String templateId = (String) r; if (templateId == null || templateId.isBlank() || templateId.equals("null")) { - // Default search template if LLM does not choose - parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE); executeQueryPlanning(parameters, listener); } else { // Retrieve search template by ID GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId); client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> { - parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource())); + if (getStoredScriptResponse.getSource() != null) { + parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource())); + } executeQueryPlanning(parameters, listener); }, e -> { listener.onFailure(e); })); } @@ -233,14 +240,38 @@ public QueryPlanningTool create(Map map) { throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'"); } else { // array is parsed as a json string - searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD)); - + String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD); + validateSearchTemplates(searchTemplatesJson); + searchTemplates = gson.toJson(searchTemplatesJson); } } return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates); } + private void validateSearchTemplates(Object searchTemplatesObj) { + List> templates = gson.fromJson(searchTemplatesObj.toString(), new TypeToken>>() { + }.getType()); + + for (Map template : templates) { + validateTemplateFields(template); + } + } + + private void validateTemplateFields(Map template) { + // Validate templateId + String templateId = template.get(TEMPLATE_ID_FIELD); + if (templateId == null || templateId.isBlank()) { + throw new IllegalArgumentException("search_templates field entries must have a template_id"); + } + + // Validate templateDescription + String templateDescription = template.get(TEMPLATE_DESCRIPTION_FIELD); + if (templateDescription == null || templateDescription.isBlank()) { + throw new IllegalArgumentException("search_templates field entries must have a template_description"); + } + } + @Override public String getDefaultDescription() { return DEFAULT_DESCRIPTION; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java index 44aa46cd89..5094ffce31 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/QueryPlanningToolTests.java @@ -95,6 +95,38 @@ public void testFactoryCreate() { assertEquals(QueryPlanningTool.TYPE, tool.getName()); } + @Test + public void testCreateWithInvalidSearchTemplatesDescription() throws IllegalArgumentException { + Map params = new HashMap<>(); + params.put("generation_type", "user_templates"); + params.put(MODEL_ID_FIELD, "test_model_id"); + params + .put( + SYSTEM_PROMPT_FIELD, + "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}" + ); + params.put("query_text", "help me find some books related to wind"); + params.put("search_templates", "[{'template_id': 'template_id', 'template_des': 'test_description'}]"); + Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params)); + assertEquals("search_templates field entries must have a template_description", exception.getMessage()); + } + + @Test + public void testCreateWithInvalidSearchTemplatesID() throws IllegalArgumentException { + Map params = new HashMap<>(); + params.put("generation_type", "user_templates"); + params.put(MODEL_ID_FIELD, "test_model_id"); + params + .put( + SYSTEM_PROMPT_FIELD, + "You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}" + ); + params.put("query_text", "help me find some books related to wind"); + params.put("search_templates", "[{'templateid': 'template_id', 'template_description': 'test_description'}]"); + Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params)); + assertEquals("search_templates field entries must have a template_id", exception.getMessage()); + } + @Test public void testRun() throws ExecutionException, InterruptedException { String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}"; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java index 2baf2a72cf..e280bc012e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestQueryPlanningToolIT.java @@ -6,7 +6,10 @@ package org.opensearch.ml.rest; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED; +import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD; import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD; +import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD; +import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD; import java.io.IOException; import java.util.List; @@ -95,6 +98,50 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException { deleteAgent(agentId); } + @Test + public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException { + if (OPENAI_KEY == null) { + return; + } + + // Create Search Templates + String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}"; + Response response = createSearchTemplate("type_search_template", templateBody); + templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}"; + response = createSearchTemplate("type_search_template_2", templateBody); + + // Register agent with search template IDs + String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates"; + String searchTemplates = "[{" + + "\"template_id\":\"type_search_template\"," + + "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\"" + + "},{" + + "\"template_id\":\"type_search_template_2\"," + + "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\"" + + "},{" + + "\"template_id\":\"brand_search_template\"," + + "\"template_description\":\"this templates searches for products that match the given brand\"" + + "}]"; + String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates); + assertNotNull(agentId); + + String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}"; + Response agentResponse = executeAgent(agentId, query); + String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity()); + + Map responseMap = gson.fromJson(responseBody, Map.class); + + List> inferenceResults = (List>) responseMap.get("inference_results"); + Map firstResult = inferenceResults.get(0); + List> outputArray = (List>) firstResult.get("output"); + Map output = (Map) outputArray.get(0); + String result = output.get("result").toString(); + + assertTrue(result.contains("query")); + assertTrue(result.contains("term")); + deleteAgent(agentId); + } + private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException { MLToolSpec listIndexTool = MLToolSpec .builder() @@ -125,6 +172,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model return registerAgent(agentName, agent); } + private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates) + throws IOException { + MLToolSpec listIndexTool = MLToolSpec + .builder() + .type("ListIndexTool") + .name("MyListIndexTool") + .description("A tool for list indices") + .parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?")) + .includeOutputInAgentResponse(true) + .build(); + + MLToolSpec queryPlanningTool = MLToolSpec + .builder() + .type("QueryPlanningTool") + .name("MyQueryPlanningTool") + .description("A tool for planning queries") + .parameters( + Map + .ofEntries( + Map.entry(MODEL_ID_FIELD, modelId), + Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD), + Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates) + ) + ) + .includeOutputInAgentResponse(true) + .build(); + + MLAgent agent = MLAgent + .builder() + .name(agentName) + .type("flow") + .description("Test agent with QueryPlanningTool") + .tools(List.of(listIndexTool, queryPlanningTool)) + .build(); + + return registerAgent(agentName, agent); + } + private String registerQueryPlanningModel() throws IOException, InterruptedException { String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5); return registerRemoteModel(openaiConnectorEntity, openaiModelName, true); @@ -177,6 +262,18 @@ private Response executeAgent(String agentId, String query) throws IOException { ); } + private Response createSearchTemplate(String templateName, String templateBody) throws IOException { + return TestHelper + .makeRequest( + client(), + "PUT", + "/_scripts/" + templateName, + null, + new StringEntity(templateBody), + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) + ); + } + private void deleteAgent(String agentId) throws IOException { TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of()); }