Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,14 @@ public class QueryPlanningPromptTemplate {
+ "- If no perfect match exists, pick the closest by the criteria above. Never output “none” or invent an id.";

public static final String TEMPLATE_SELECTION_INPUTS = "question: ${parameters.query_text}\n"
+ "templates: ${parameters.search_templates}";
+ "search_templates: ${parameters.search_templates}";

public static final String TEMPLATE_SELECTION_EXAMPLES = "Example A: \n"
+ "question: 'what shoes are highly rated'\n"
+ "templates:\n"
+ "search_templates :\n"
+ "[\n"
+ "{'id':'product-search-template','description':'Searches products in an e-commerce store.'},\n"
+ "{'id':'sales-value-analysis-template','description':'Aggregates sales value for top-selling products.'}\n"
+ "{'template_id':'product-search-template','template_description':'Searches products in an e-commerce store.'},\n"
+ "{'template_id':'sales-value-analysis-template','template_description':'Aggregates sales value for top-selling products.'}\n"
+ "]\n"
+ "Example output : 'product-search-template'";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.opensearch.ml.common.utils.ToolUtils;
import org.opensearch.transport.client.Client;

import com.google.gson.reflect.TypeToken;

import lombok.Getter;
import lombok.Setter;

Expand All @@ -46,13 +48,16 @@ public class QueryPlanningTool implements WithModelTool {
public static final String USER_PROMPT_FIELD = "user_prompt";
public static final String INDEX_MAPPING_FIELD = "index_mapping";
public static final String QUERY_FIELDS_FIELD = "query_fields";
private static final String GENERATION_TYPE_FIELD = "generation_type";
public static final String GENERATION_TYPE_FIELD = "generation_type";
private static final String LLM_GENERATED_TYPE_FIELD = "llmGenerated";
private static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
private static final String SEARCH_TEMPLATES_FIELD = "search_templates";
public static final String USER_SEARCH_TEMPLATES_TYPE_FIELD = "user_templates";
public static final String SEARCH_TEMPLATES_FIELD = "search_templates";
public static final String TEMPLATE_FIELD = "template";
private static final String TEMPLATE_ID_FIELD = "template_id";
private static final String TEMPLATE_DESCRIPTION_FIELD = "template_description";
private static final String DEFAULT_SYSTEM_PROMPT =
"You are an OpenSearch Query DSL generation assistant, translating natural language questions to OpenSeach DSL Queries";

@Getter
private final String generationType;
@Getter
Expand Down Expand Up @@ -102,17 +107,19 @@ public <T> void run(Map<String, String> originalParameters, ActionListener<T> li
templateSelectionParameters.put(SEARCH_TEMPLATES_FIELD, searchTemplates);

ActionListener<T> templateSelectionListener = ActionListener.wrap(r -> {
// Default search template if LLM does not choose or if returned search template is null
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
try {
String templateId = (String) r;
if (templateId == null || templateId.isBlank() || templateId.equals("null")) {
// Default search template if LLM does not choose
parameters.put(TEMPLATE_FIELD, DEFAULT_SEARCH_TEMPLATE);
executeQueryPlanning(parameters, listener);
} else {
// Retrieve search template by ID
GetStoredScriptRequest getStoredScriptRequest = new GetStoredScriptRequest(templateId);
client.admin().cluster().getStoredScript(getStoredScriptRequest, ActionListener.wrap(getStoredScriptResponse -> {
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
if (getStoredScriptResponse.getSource() != null) {
parameters.put(TEMPLATE_FIELD, gson.toJson(getStoredScriptResponse.getSource().getSource()));
}
executeQueryPlanning(parameters, listener);
}, e -> { listener.onFailure(e); }));
}
Expand Down Expand Up @@ -233,14 +240,38 @@ public QueryPlanningTool create(Map<String, Object> map) {
throw new IllegalArgumentException("search_templates field is required when generation_type is 'user_templates'");
} else {
// array is parsed as a json string
searchTemplates = gson.toJson((String) map.get(SEARCH_TEMPLATES_FIELD));

String searchTemplatesJson = (String) map.get(SEARCH_TEMPLATES_FIELD);
validateSearchTemplates(searchTemplatesJson);
searchTemplates = gson.toJson(searchTemplatesJson);
}
}

return new QueryPlanningTool(type, queryGenerationTool, client, searchTemplates);
}

private void validateSearchTemplates(Object searchTemplatesObj) {
List<Map<String, String>> templates = gson.fromJson(searchTemplatesObj.toString(), new TypeToken<List<Map<String, String>>>() {
}.getType());

for (Map<String, String> template : templates) {
validateTemplateFields(template);
}
}

private void validateTemplateFields(Map<String, String> template) {
// Validate templateId
String templateId = template.get(TEMPLATE_ID_FIELD);
if (templateId == null || templateId.isBlank()) {
throw new IllegalArgumentException("search_templates field entries must have a template_id");
}

// Validate templateDescription
String templateDescription = template.get(TEMPLATE_DESCRIPTION_FIELD);
if (templateDescription == null || templateDescription.isBlank()) {
throw new IllegalArgumentException("search_templates field entries must have a template_description");
}
}

@Override
public String getDefaultDescription() {
return DEFAULT_DESCRIPTION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,38 @@ public void testFactoryCreate() {
assertEquals(QueryPlanningTool.TYPE, tool.getName());
}

@Test
public void testCreateWithInvalidSearchTemplatesDescription() throws IllegalArgumentException {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params
.put(
SYSTEM_PROMPT_FIELD,
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
);
params.put("query_text", "help me find some books related to wind");
params.put("search_templates", "[{'template_id': 'template_id', 'template_des': 'test_description'}]");
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
}

@Test
public void testCreateWithInvalidSearchTemplatesID() throws IllegalArgumentException {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params
.put(
SYSTEM_PROMPT_FIELD,
"You are a query generation agent. Generate a dsl query for the following question: ${parameters.query_text}"
);
params.put("query_text", "help me find some books related to wind");
params.put("search_templates", "[{'templateid': 'template_id', 'template_description': 'test_description'}]");
Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
}

@Test
public void testRun() throws ExecutionException, InterruptedException {
String matchQueryString = "{\"query\":{\"match\":{\"title\":\"wind\"}}}";
Expand Down Expand Up @@ -552,4 +584,101 @@ public void testFactoryCreateWhenAgenticSearchDisabled() {
Exception exception = assertThrows(OpenSearchException.class, () -> factory.create(map));
assertEquals(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE, exception.getMessage());
}

@Test
public void testCreateWithValidSearchTemplates() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params
.put(
"search_templates",
"[{'template_id': 'template1', 'template_description': 'description1'}, {'template_id': 'template2', 'template_description': 'description2'}]"
);

QueryPlanningTool tool = factory.create(params);
assertNotNull(tool);
assertEquals("user_templates", tool.getGenerationType());
}

@Test
public void testCreateWithEmptySearchTemplatesList() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params.put("search_templates", "[]");

QueryPlanningTool tool = factory.create(params);
assertNotNull(tool);
assertEquals("user_templates", tool.getGenerationType());
}

@Test
public void testCreateWithMissingSearchTemplatesField() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");

Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field is required when generation_type is 'user_templates'", exception.getMessage());
}

@Test
public void testCreateWithInvalidSearchTemplatesJson() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params.put("search_templates", "invalid_json");

assertThrows(com.google.gson.JsonSyntaxException.class, () -> factory.create(params));
}

@Test
public void testCreateWithNullTemplateId() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params.put("search_templates", "[{'template_id': null, 'template_description': 'description'}]");

Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
}

@Test
public void testCreateWithBlankTemplateDescription() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params.put("search_templates", "[{'template_id': 'template1', 'template_description': ' '}]");

Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field entries must have a template_description", exception.getMessage());
}

@Test
public void testCreateWithMixedValidAndInvalidTemplates() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params
.put(
"search_templates",
"[{'template_id': 'template1', 'template_description': 'description1'}, {'template_description': 'description2'}]"
);

Exception exception = assertThrows(IllegalArgumentException.class, () -> factory.create(params));
assertEquals("search_templates field entries must have a template_id", exception.getMessage());
}

@Test
public void testCreateWithExtraFieldsInSearchTemplates() {
Map<String, Object> params = new HashMap<>();
params.put("generation_type", "user_templates");
params.put(MODEL_ID_FIELD, "test_model_id");
params.put("search_templates", "[{'template_id': 'template1', 'template_description': 'description1', 'extra_field': 'value'}]");

QueryPlanningTool tool = factory.create(params);
assertNotNull(tool);
assertEquals("user_templates", tool.getGenerationType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.GENERATION_TYPE_FIELD;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.MODEL_ID_FIELD;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.SEARCH_TEMPLATES_FIELD;
import static org.opensearch.ml.engine.tools.QueryPlanningTool.USER_SEARCH_TEMPLATES_TYPE_FIELD;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -95,6 +98,49 @@ public void testAgentWithQueryPlanningTool_DefaultPrompt() throws IOException {
deleteAgent(agentId);
}

@Test
public void testAgentWithQueryPlanningTool_SearchTemplates() throws IOException {
if (OPENAI_KEY == null) {
return;
}

// Create Search Templates
String templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"match\":{\"type\":\"{{type}}\"}}}}}";
Response response = createSearchTemplate("type_search_template", templateBody);
templateBody = "{\"script\":{\"lang\":\"mustache\",\"source\":{\"query\":{\"term\":{\"type\":\"{{type}}\"}}}}}";
response = createSearchTemplate("type_search_template_2", templateBody);

// Register agent with search template IDs
String agentName = "Test_AgentWithQueryPlanningTool_SearchTemplates";
String searchTemplates = "[{"
+ "\"template_id\":\"type_search_template\","
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a match query\""
+ "},{"
+ "\"template_id\":\"type_search_template_2\","
+ "\"template_description\":\"this templates searches for flowers that match the given type this uses a term query\""
+ "},{"
+ "\"template_id\":\"brand_search_template\","
+ "\"template_description\":\"this templates searches for products that match the given brand\""
+ "}]";
String agentId = registerQueryPlanningAgentWithSearchTemplates(agentName, queryPlanningModelId, searchTemplates);
assertNotNull(agentId);

String query = "{\"parameters\": {\"query_text\": \"List 5 iris flowers of type setosa\"}}";
Response agentResponse = executeAgent(agentId, query);
String responseBody = TestHelper.httpEntityToString(agentResponse.getEntity());

Map<String, Object> responseMap = gson.fromJson(responseBody, Map.class);

List<Map<String, Object>> inferenceResults = (List<Map<String, Object>>) responseMap.get("inference_results");
Map<String, Object> firstResult = inferenceResults.get(0);
List<Map<String, Object>> outputArray = (List<Map<String, Object>>) firstResult.get("output");
Map<String, Object> output = (Map<String, Object>) outputArray.get(0);
String result = output.get("result").toString();

assertTrue(result.contains("query"));
deleteAgent(agentId);
}

private String registerAgentWithQueryPlanningTool(String agentName, String modelId) throws IOException {
MLToolSpec listIndexTool = MLToolSpec
.builder()
Expand Down Expand Up @@ -125,6 +171,44 @@ private String registerAgentWithQueryPlanningTool(String agentName, String model
return registerAgent(agentName, agent);
}

private String registerQueryPlanningAgentWithSearchTemplates(String agentName, String modelId, String searchTemplates)
throws IOException {
MLToolSpec listIndexTool = MLToolSpec
.builder()
.type("ListIndexTool")
.name("MyListIndexTool")
.description("A tool for list indices")
.parameters(Map.of("index", IRIS_INDEX, "question", "what fields are in the index?"))
.includeOutputInAgentResponse(true)
.build();

MLToolSpec queryPlanningTool = MLToolSpec
.builder()
.type("QueryPlanningTool")
.name("MyQueryPlanningTool")
.description("A tool for planning queries")
.parameters(
Map
.ofEntries(
Map.entry(MODEL_ID_FIELD, modelId),
Map.entry(GENERATION_TYPE_FIELD, USER_SEARCH_TEMPLATES_TYPE_FIELD),
Map.entry(SEARCH_TEMPLATES_FIELD, searchTemplates)
)
)
.includeOutputInAgentResponse(true)
.build();

MLAgent agent = MLAgent
.builder()
.name(agentName)
.type("flow")
.description("Test agent with QueryPlanningTool")
.tools(List.of(listIndexTool, queryPlanningTool))
.build();

return registerAgent(agentName, agent);
}

private String registerQueryPlanningModel() throws IOException, InterruptedException {
String openaiModelName = "openai gpt-4o model " + randomAlphaOfLength(5);
return registerRemoteModel(openaiConnectorEntity, openaiModelName, true);
Expand Down Expand Up @@ -177,6 +261,18 @@ private Response executeAgent(String agentId, String query) throws IOException {
);
}

private Response createSearchTemplate(String templateName, String templateBody) throws IOException {
return TestHelper
.makeRequest(
client(),
"PUT",
"/_scripts/" + templateName,
null,
new StringEntity(templateBody),
List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json"))
);
}

private void deleteAgent(String agentId) throws IOException {
TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/agents/" + agentId, null, "", List.of());
}
Expand Down
Loading