From fef8e706c9d1f44d64e697d0c8aa648bccfcd3b6 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 22 Oct 2025 16:57:09 -0700 Subject: [PATCH 01/58] agentic memory integration with agent framework (#4331) Signed-off-by: Dhrubo Saha --- .../org/opensearch/ml/common/MLAgentType.java | 2 +- .../opensearch/ml/common/MLMemoryType.java | 24 + .../opensearch/ml/common/agent/MLAgent.java | 16 +- .../transport/agent/MLAgentUpdateInput.java | 5 +- .../ml/common/MLAgentTypeTests.java | 4 +- .../agent/MLAgentUpdateInputTest.java | 2 +- .../agent/AgenticMemoryAdapter.java | 775 +++++++++++++++++ .../agent/ChatHistoryTemplateEngine.java | 55 ++ .../algorithms/agent/ChatMemoryAdapter.java | 124 +++ .../engine/algorithms/agent/ChatMessage.java | 37 + .../algorithms/agent/MLAgentExecutor.java | 668 +++++++++----- .../algorithms/agent/MLChatAgentRunner.java | 814 ++++++++++++++++-- .../SimpleChatHistoryTemplateEngine.java | 81 ++ .../ml/engine/memory/ChatMemoryAdapter.java | 0 .../ml/engine/memory/ChatMessage.java | 32 + .../agent/AgenticMemoryAdapterTest.java | 167 ++++ .../agent/ChatMemoryAdapterTest.java | 115 +++ .../agent/MLChatAgentRunnerTest.java | 132 +++ .../ml/helper/MemoryContainerHelper.java | 5 +- 19 files changed, 2740 insertions(+), 318 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/MLMemoryType.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java index 2dd2614634..04a4b72014 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -20,7 +20,7 @@ public static MLAgentType from(String value) { try { return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT)); } catch (Exception e) { - throw new IllegalArgumentException("Wrong Agent type"); + throw new IllegalArgumentException(value + " is not a valid Agent Type"); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java new file mode 100644 index 0000000000..31939ce1ca --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.util.Locale; + +public enum MLMemoryType { + CONVERSATION_INDEX, + AGENTIC_MEMORY; + + public static MLMemoryType from(String value) { + if (value 
!= null) { + try { + return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Memory type"); + } + } + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..ec73d73856 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -15,7 +15,6 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -113,7 +112,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - validateMLAgentType(type); + MLAgentType.from(type); if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { throw new IllegalArgumentException("We need model information for the conversational agent type"); } @@ -130,19 +129,6 @@ private void validate() { } } - private void validateMLAgentType(String agentType) { - if (type == null) { - throw new IllegalArgumentException("Agent type can't be null"); - } else { - try { - MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching - } catch (IllegalArgumentException e) { - // The typeStr does not match any MLAgentType, so throw a new exception with a clearer message. - throw new IllegalArgumentException(agentType + " is not a valid Agent Type"); - } - } - } - public MLAgent(StreamInput input) throws IOException { Version streamInputVersion = input.getVersion(); name = input.readString(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 9a0d6002fd..e85b3f4bdc 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -26,6 +26,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -383,9 +384,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - if (memoryType != null && !memoryType.equals("conversation_index")) { - throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType)); - } + MLMemoryType.from(memoryType); if (tools != null) { Set toolNames = new HashSet<>(); for (MLToolSpec toolSpec : tools) { diff --git a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java index ee15ca95fd..05f37c4992 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java @@ -44,14 +44,14 @@ public void testFromWithMixedCase() { public void testFromWithInvalidType() { // This should throw an IllegalArgumentException 
exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); MLAgentType.from("INVALID_TYPE"); } @Test public void testFromWithEmptyString() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); // This should also throw an IllegalArgumentException MLAgentType.from(""); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index 72eb035279..084f95d137 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build(); }); - assertEquals("Invalid memory type: invalid_type", e.getMessage()); + assertEquals("Wrong Memory type", e.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java new file mode 100644 index 0000000000..6bf685bd7f --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java @@ -0,0 +1,775 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.memorycontainer.MemoryType; +import org.opensearch.ml.common.memorycontainer.PayloadType; +import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetAction; +import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import 
org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Adapter for Agentic Memory system to work with MLChatAgentRunner. + * + *

+ * <p>This adapter provides a bridge between the ML Chat Agent system and the Agentic Memory
+ * infrastructure, enabling intelligent conversation management and context retention.</p>
+ *
+ * <p>Memory Types Handled:</p>
+ * <ul>
+ *   <li>WORKING memory: Recent conversation history and active interactions</li>
+ *   <li>LONG_TERM memory: Extracted facts, summaries, and semantic insights</li>
+ * </ul>
+ *
+ * <p>Key Features:</p>
+ * <ul>
+ *   <li>Dual memory querying for comprehensive context retrieval</li>
+ *   <li>Dynamic inference configuration based on memory container LLM settings</li>
+ *   <li>Structured trace data storage for tool invocation tracking</li>
+ *   <li>Robust error handling with fallback mechanisms</li>
+ * </ul>
+ *
+ * <p>Usage Example:</p>
+ * <pre>{@code
+ * AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(
+ *     client, "memory-container-id", "session-123", "user-456"
+ * );
+ *
+ * // Retrieve conversation messages
+ * adapter.getMessages(ActionListener.wrap(
+ *     messages -> processMessages(messages),
+ *     error -> handleError(error)
+ * ));
+ *
+ * // Save trace data
+ * adapter.saveTraceData("search_tool", "query", "results",
+ *     "parent-id", 1, "search", listener);
+ * }</pre>
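+ *
+ * <p>Persisting and then updating a turn (a sketch; the listener variables are placeholders and the
+ * {@code "response"}/{@code "input"} update keys follow {@link #updateInteraction}):</p>
+ * <pre>{@code
+ * adapter.saveInteraction("question", "answer", null, 1, "chat", ActionListener.wrap(
+ *     interactionId -> adapter.updateInteraction(interactionId,
+ *         Map.of("response", "final answer", "input", "question"), updateListener),
+ *     error -> handleError(error)
+ * ));
+ * }</pre>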
+ * + * @see ChatMemoryAdapter + * @see MLChatAgentRunner + */ +@Log4j2 +public class AgenticMemoryAdapter implements ChatMemoryAdapter { + private final Client client; + private final String memoryContainerId; + private final String sessionId; + private final String ownerId; + + /** + * Creates a new AgenticMemoryAdapter instance. + * + * @param client OpenSearch client for executing memory operations + * @param memoryContainerId Unique identifier for the memory container + * @param sessionId Session identifier for conversation context + * @param ownerId Owner/user identifier for access control + * @throws IllegalArgumentException if any required parameter is null + */ + public AgenticMemoryAdapter(Client client, String memoryContainerId, String sessionId, String ownerId) { + if (client == null) { + throw new IllegalArgumentException("Client cannot be null"); + } + if (memoryContainerId == null || memoryContainerId.trim().isEmpty()) { + throw new IllegalArgumentException("Memory container ID cannot be null or empty"); + } + if (sessionId == null || sessionId.trim().isEmpty()) { + throw new IllegalArgumentException("Session ID cannot be null or empty"); + } + if (ownerId == null || ownerId.trim().isEmpty()) { + throw new IllegalArgumentException("Owner ID cannot be null or empty"); + } + + this.client = client; + this.memoryContainerId = memoryContainerId; + this.sessionId = sessionId; + this.ownerId = ownerId; + } + + @Override + public void getMessages(ActionListener> listener) { + // Query both WORKING memory (recent conversations) and LONG_TERM memory + // (extracted facts) + // This provides both conversation history and learned context + + List allChatMessages = new ArrayList<>(); + AtomicInteger pendingRequests = new AtomicInteger(2); + + // 1. Get recent conversation history from WORKING memory + SearchSourceBuilder workingSearchBuilder = new SearchSourceBuilder() + .query( + QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) + .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) + ) + .sort("created_time", SortOrder.DESC) + .size(50); // Limit recent conversation history + + MLSearchMemoriesRequest workingRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput( + MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(workingSearchBuilder) + .build() + ) + .build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, workingRequest, ActionListener.wrap(workingResponse -> { + synchronized (allChatMessages) { + allChatMessages.addAll(parseAgenticMemoryResponse(workingResponse)); + if (pendingRequests.decrementAndGet() == 0) { + // Sort all chat messages by timestamp and return + allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + listener.onResponse(allChatMessages); + } + } + }, listener::onFailure)); + + // 2. 
Get relevant context from LONG_TERM memory (extracted facts) + SearchSourceBuilder longTermSearchBuilder = new SearchSourceBuilder() + .query( + QueryBuilders + .boolQuery() + .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) + .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) + .should(QueryBuilders.termQuery("strategy_type", "SUMMARY")) + .should(QueryBuilders.termQuery("strategy_type", "SEMANTIC")) + ) + .sort("created_time", SortOrder.DESC) + .size(10); // Limit context facts + + MLSearchMemoriesRequest longTermRequest = MLSearchMemoriesRequest + .builder() + .mlSearchMemoriesInput( + MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.LONG_TERM) + .searchSourceBuilder(longTermSearchBuilder) + .build() + ) + .build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, longTermRequest, ActionListener.wrap(longTermResponse -> { + synchronized (allChatMessages) { + allChatMessages.addAll(parseAgenticMemoryResponse(longTermResponse)); + if (pendingRequests.decrementAndGet() == 0) { + // Sort all chat messages by timestamp and return + allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + listener.onResponse(allChatMessages); + } + } + }, e -> { + // If long-term memory fails, still return working memory results + log.warn("Failed to retrieve long-term memory, continuing with working memory only", e); + synchronized (allChatMessages) { + if (pendingRequests.decrementAndGet() == 0) { + listener.onResponse(allChatMessages); + } + } + })); + } + + @Override + public String getConversationId() { + return sessionId; + } + + @Override + public String getMemoryContainerId() { + return memoryContainerId; + } + + @Override + public void saveInteraction( + String question, + String assistantResponse, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + final String finalQuestion = question != null ? question : ""; + final String finalAssistantResponse = assistantResponse != null ? 
assistantResponse : ""; + + log + .info( + "AgenticMemoryAdapter.saveInteraction: Called with parentId: {}, action: {}, hasResponse: {}", + parentId, + action, + !finalAssistantResponse.isEmpty() + ); + + // If parentId is provided and we have a response, update the existing + // interaction + if (parentId != null && !finalAssistantResponse.isEmpty()) { + log.info("AgenticMemoryAdapter.saveInteraction: Updating existing interaction {} with final response", parentId); + + // Update the existing interaction with the complete conversation + Map updateFields = new HashMap<>(); + updateFields.put("response", finalAssistantResponse); + updateFields.put("input", finalQuestion); + + updateInteraction(parentId, updateFields, ActionListener.wrap(res -> { + log.info("AgenticMemoryAdapter.saveInteraction: Successfully updated interaction {}", parentId); + listener.onResponse(parentId); // Return the same interaction ID + }, ex -> { + log + .error( + "AgenticMemoryAdapter.saveInteraction: Failed to update interaction {}, falling back to create new", + parentId, + ex + ); + // Fallback to creating new interaction if update fails + createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); + })); + } else { + // Create new interaction (root interaction or when no parentId) + log.info("AgenticMemoryAdapter.saveInteraction: Creating new interaction - parentId: {}, action: {}", parentId, action); + createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); + } + } + + private void createNewInteraction( + String question, + String assistantResponse, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + List messages = Arrays + .asList( + MessageInput.builder().role("user").content(createTextContent(question)).build(), + MessageInput.builder().role("assistant").content(createTextContent(assistantResponse)).build() + ); + + // Create namespace map with proper String types + Map namespaceMap = new java.util.HashMap<>(); + namespaceMap.put("session_id", sessionId); + namespaceMap.put("user_id", ownerId); + + Map metadataMap = new java.util.HashMap<>(); + if (traceNum != null) { + metadataMap.put("trace_num", traceNum.toString()); + } + if (action != null) { + metadataMap.put("action", action); + } + if (parentId != null) { + metadataMap.put("parent_id", parentId); + } + + // Check if memory container has LLM ID configured to determine infer value + hasLlmIdConfigured(ActionListener.wrap(hasLlmId -> { + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .messages(messages) + .namespace(namespaceMap) + .metadata(metadataMap) + .ownerId(ownerId) + .infer(hasLlmId) // Use dynamic infer based on LLM ID presence + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: {}", + addResponse.getWorkingMemoryId(), + addResponse.getSessionId(), + action, + hasLlmId + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, listener::onFailure)); + }, ex -> { + log.warn("Failed to check LLM ID configuration for interaction, proceeding with infer=false", ex); + // Fallback to infer=false if we can't determine LLM ID status + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + 
.memoryContainerId(memoryContainerId) + .messages(messages) + .namespace(namespaceMap) + .metadata(metadataMap) + .ownerId(ownerId) + .infer(false) + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: false (fallback)", + addResponse.getWorkingMemoryId(), + addResponse.getSessionId(), + action + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, listener::onFailure)); + })); + } + + /** + * Save trace data as structured tool invocation information in working memory. + * + *

+ * <p>This method stores detailed information about tool executions, including inputs,
+ * outputs, and contextual metadata. The data is stored with appropriate tags and
+ * namespace information for later retrieval and analysis.</p>
+ *
+ * <p>Important: This method always uses {@code infer=false} to prevent
+ * LLM-based long-term memory extraction from tool traces. Tool execution data is already
+ * structured and queryable, and extracting facts from intermediate steps would create
+ * fragmented, duplicate long-term memories. Semantic extraction happens only on final
+ * conversation interactions via {@link #saveInteraction}.</p>
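+ *
+ * <p>Shape of the structured data written for one call (a sketch; values are illustrative and the
+ * keys mirror the tool_invocations entry built in this method):</p>
+ * <pre>{@code
+ * "tool_invocations": [
+ *   { "tool_name": "search_tool", "tool_input": "query", "tool_output": "results" }
+ * ]
+ * }</pre>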

+ * + * @param toolName Name of the tool that was executed (required, non-empty) + * @param toolInput Input parameters passed to the tool (nullable, defaults to empty string) + * @param toolOutput Output/response from the tool execution (nullable, defaults to empty string) + * @param parentMemoryId Parent memory ID to associate this trace with (nullable) + * @param traceNum Trace sequence number for ordering (nullable) + * @param action Action/origin identifier for categorization (nullable) + * @param listener ActionListener to handle the response with the created memory ID + * @throws IllegalArgumentException if toolName is null/empty or listener is null + * @see #saveInteraction for conversational data that triggers long-term memory extraction + */ + @Override + public void saveTraceData( + String toolName, + String toolInput, + String toolOutput, + String parentMemoryId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (toolName == null || toolName.trim().isEmpty()) { + throw new IllegalArgumentException("Tool name cannot be null or empty"); + } + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + final String finalToolName = toolName; + + // Create tool invocation structured data + Map toolInvocation = new HashMap<>(); + toolInvocation.put("tool_name", finalToolName); + toolInvocation.put("tool_input", toolInput != null ? toolInput : ""); + toolInvocation.put("tool_output", toolOutput != null ? toolOutput : ""); + + Map structuredData = new HashMap<>(); + structuredData.put("tool_invocations", List.of(toolInvocation)); + + // Create namespace map + Map namespaceMap = new HashMap<>(); + namespaceMap.put("session_id", sessionId); + namespaceMap.put("user_id", ownerId); + + // Create metadata map + Map metadataMap = new HashMap<>(); + metadataMap.put("status", "checkpoint"); + if (traceNum != null) { + metadataMap.put("trace_num", traceNum.toString()); + } + if (action != null) { + metadataMap.put("action", action); + } + if (parentMemoryId != null) { + metadataMap.put("parent_memory_id", parentMemoryId); + } + + // Create tags map with trace-specific information + Map tagsMap = new HashMap<>(); + tagsMap.put("data_type", "trace"); + + if (action != null) { + tagsMap.put("topic", action); + } + + /* + * IMPORTANT: Tool trace data uses infer=false to prevent long-term memory extraction + * + * Rationale: + * 1. Tool traces are intermediate execution steps, not final user-facing content + * 2. Running LLM inference on tool traces would create fragmented, low-quality long-term memories + * 3. Multiple tool executions in a single conversation would generate redundant/duplicate facts + * 4. Tool trace data is already structured (tool_name, tool_input, tool_output) and queryable + * 5. Final conversation interactions (saveInteraction) will trigger proper semantic extraction + * + * Example problem if infer=true: + * User: "What's the weather in Seattle?" 
+ * - Tool trace saved → LLM extracts: "User queried Seattle" (incomplete context) + * - Final response saved → LLM extracts: "User asked about Seattle weather" (complete context) + * Result: Duplicate/conflicting long-term memories + * + * By setting infer=false for tool traces: + * - Tool execution data remains queryable via structured data search + * - Long-term memory extraction happens only on final, contextually complete interactions + * - Cleaner, more accurate long-term memory without duplication + * - Reduced LLM inference costs and processing overhead + */ + executeTraceDataSave(structuredData, namespaceMap, metadataMap, tagsMap, false, finalToolName, action, listener); + } + + /** + * Execute the actual trace data save operation. + * + *

+ * <p>Note: The infer parameter is kept for potential future use cases where selective
+ * inference on tool traces might be needed, but currently always receives false to
+ * prevent duplicate long-term memory extraction.</p>

+ * + * @param structuredData The structured data containing tool invocation information + * @param namespaceMap The namespace mapping for the memory + * @param metadataMap The metadata for the memory entry + * @param tagsMap The tags for the memory entry + * @param infer Whether to enable inference processing (currently always false for tool traces) + * @param toolName The name of the tool (for logging) + * @param action The action identifier (for logging) + * @param listener ActionListener to handle the response + */ + private void executeTraceDataSave( + Map structuredData, + Map namespaceMap, + Map metadataMap, + Map tagsMap, + boolean infer, + String toolName, + String action, + ActionListener listener + ) { + try { + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .structuredData(structuredData) + .namespace(namespaceMap) + .metadata(metadataMap) + .tags(tagsMap) + .ownerId(ownerId) + .payloadType(PayloadType.DATA) + .infer(infer) + .build(); + + MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); + + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { + log + .info( + "AgenticMemoryAdapter.saveTraceData: Successfully saved trace data with ID: {}, toolName: {}, action: {}, infer: {}", + addResponse.getWorkingMemoryId(), + toolName, + action, + infer + ); + listener.onResponse(addResponse.getWorkingMemoryId()); + }, ex -> { + log + .error( + "AgenticMemoryAdapter.saveTraceData: Failed to save trace data for tool: {}, action: {}, infer: {}. Error: {}", + toolName, + action, + infer, + ex.getMessage(), + ex + ); + listener.onFailure(ex); + })); + } catch (Exception e) { + log + .error( + "AgenticMemoryAdapter.saveTraceData: Exception while building trace data save request for tool: {}, action: {}", + toolName, + action, + e + ); + listener.onFailure(e); + } + } + + /** + * Check if the memory container has an LLM ID configured for inference + * @param callback ActionListener to handle the result (true if LLM ID exists, false otherwise) + */ + private void hasLlmIdConfigured(ActionListener callback) { + MLMemoryContainerGetRequest getRequest = MLMemoryContainerGetRequest.builder().memoryContainerId(memoryContainerId).build(); + + client.execute(MLMemoryContainerGetAction.INSTANCE, getRequest, ActionListener.wrap(response -> { + boolean hasLlmId = response.getMlMemoryContainer().getConfiguration().getLlmId() != null; + log.info("Memory container {} has LLM ID configured: {}", memoryContainerId, hasLlmId); + callback.onResponse(hasLlmId); + }, ex -> { + log + .warn( + "Failed to get memory container {} configuration, defaulting infer to false. 
Error: {}", + memoryContainerId, + ex.getMessage(), + ex + ); + callback.onResponse(false); + })); + } + + private List> createTextContent(String text) { + return List.of(Map.of("type", "text", "text", text)); + } + + private List parseAgenticMemoryResponse(SearchResponse response) { + List chatMessages = new ArrayList<>(); + + for (SearchHit hit : response.getHits().getHits()) { + Map source = hit.getSourceAsMap(); + + // Parse working memory documents (conversational format) + if ("conversational".equals(source.get("payload_type"))) { + @SuppressWarnings("unchecked") + List> messages = (List>) source.get("messages"); + if (messages != null && messages.size() >= 2) { + // Extract user question and assistant response + String question = extractMessageText(messages.get(0)); // user message + String assistantResponse = extractMessageText(messages.get(1)); // assistant message + + if (question != null && assistantResponse != null) { + // Add user message + ChatMessage userMessage = ChatMessage + .builder() + .id(hit.getId() + "_user") + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(getSessionIdFromNamespace(source)) + .role("user") + .content(question) + .contentType("text") + .origin("agentic_memory_working") + .metadata( + Map + .of( + "payload_type", + source.get("payload_type"), + "memory_container_id", + source.get("memory_container_id"), + "namespace", + source.get("namespace"), + "tags", + source.get("tags") + ) + ) + .build(); + chatMessages.add(userMessage); + + // Add assistant message + ChatMessage assistantMessage = ChatMessage + .builder() + .id(hit.getId() + "_assistant") + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(getSessionIdFromNamespace(source)) + .role("assistant") + .content(assistantResponse) + .contentType("text") + .origin("agentic_memory_working") + .metadata( + Map + .of( + "payload_type", + source.get("payload_type"), + "memory_container_id", + source.get("memory_container_id"), + "namespace", + source.get("namespace"), + "tags", + source.get("tags") + ) + ) + .build(); + chatMessages.add(assistantMessage); + } + } + } + // Parse long-term memory documents (extracted facts) + else if (source.containsKey("memory") && source.containsKey("strategy_type")) { + String memory = (String) source.get("memory"); + String strategyType = (String) source.get("strategy_type"); + + // Convert extracted facts to chat message format for context + ChatMessage contextMessage = ChatMessage + .builder() + .id(hit.getId()) + .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) + .sessionId(sessionId) // Use current session + .role("system") // System context message + .content("Context (" + strategyType + "): " + memory) // The extracted fact with context + .contentType("context") + .origin("agentic_memory_longterm") + .metadata( + Map + .of( + "strategy_type", + strategyType, + "strategy_id", + source.get("strategy_id"), + "memory_container_id", + source.get("memory_container_id") + ) + ) + .build(); + chatMessages.add(contextMessage); + } + } + + // Sort by timestamp to maintain chronological order + chatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); + + return chatMessages; + } + + private String extractMessageText(Map message) { + if (message == null) + return null; + + @SuppressWarnings("unchecked") + List> content = (List>) message.get("content"); + if (content != null && !content.isEmpty()) { + Map firstContent = content.get(0); + return (String) 
firstContent.get("text"); + } + return null; + } + + private String getSessionIdFromNamespace(Map source) { + @SuppressWarnings("unchecked") + Map namespace = (Map) source.get("namespace"); + return namespace != null ? (String) namespace.get("session_id") : null; + } + + @Override + public void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { + if (listener == null) { + throw new IllegalArgumentException("ActionListener cannot be null"); + } + if (interactionId == null || interactionId.trim().isEmpty()) { + listener.onFailure(new IllegalArgumentException("Interaction ID is required and cannot be empty")); + return; + } + if (updateFields == null || updateFields.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Update fields are required and cannot be empty")); + return; + } + + try { + log + .info( + "AgenticMemoryAdapter.updateInteraction: CALLED - Updating interaction {} with fields: {} in memory container: {}", + interactionId, + updateFields.keySet(), + memoryContainerId + ); + + // Convert updateFields to the format expected by memory container API + Map updateContent = new java.util.HashMap<>(); + + // Handle the response field - this is the main field we need to update + if (updateFields.containsKey("response")) { + String response = (String) updateFields.get("response"); + String question = (String) updateFields.getOrDefault("input", ""); + + // For working memory updates, we need to provide the complete messages array + // with both user question and assistant response + List> messages = Arrays + .asList( + Map.of("role", "user", "content", createTextContent(question)), + Map.of("role", "assistant", "content", createTextContent(response)) + ); + + updateContent.put("messages", messages); + + log + .debug( + "AgenticMemoryAdapter.updateInteraction: Updating messages for interaction {} with question: '{}' and response length: {}", + interactionId, + question.length() > 50 ? question.substring(0, 50) + "..." 
: question, + response.length() + ); + } + + // Handle other fields that might be updated + if (updateFields.containsKey("additional_info")) { + updateContent.put("additional_info", updateFields.get("additional_info")); + } + + MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(updateContent).build(); + + MLUpdateMemoryRequest request = MLUpdateMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) // We're updating working memory + .memoryId(interactionId) + .mlUpdateMemoryInput(input) + .build(); + + client.execute(MLUpdateMemoryAction.INSTANCE, request, ActionListener.wrap(updateResponse -> { + log + .debug( + "AgenticMemoryAdapter.updateInteraction: Successfully updated interaction {} in memory container: {}", + interactionId, + memoryContainerId + ); + listener.onResponse(null); + }, ex -> { + log + .error( + "AgenticMemoryAdapter.updateInteraction: Failed to update interaction {} in memory container {}", + interactionId, + memoryContainerId, + ex + ); + listener.onFailure(ex); + })); + + } catch (Exception e) { + log + .error( + "AgenticMemoryAdapter.updateInteraction: Exception while updating interaction {} in memory container {}", + interactionId, + memoryContainerId, + e + ); + listener.onFailure(e); + } + } + +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java new file mode 100644 index 0000000000..80743ba3c5 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; +import java.util.Map; + +/** + * Enhanced template system for ChatMessage-based memory types. + * Supports flexible templating with role-based formatting and metadata access. 
+ */ +public interface ChatHistoryTemplateEngine { + /** + * Build chat history from ChatMessage list using template + * @param messages List of ChatMessage objects + * @param template Template string with placeholders + * @param context Additional context variables + * @return Formatted chat history string + */ + String buildChatHistory(List messages, String template, Map context); + + /** + * Get default template for basic chat history formatting + * @return Default template string + */ + default String getDefaultTemplate() { + return "{{#each messages}}{{role}}: {{content}}\n{{/each}}"; + } + + /** + * Get role-based template with enhanced formatting + * @return Role-based template string + */ + default String getRoleBasedTemplate() { + return """ + {{#each messages}} + {{#if (eq role 'user')}} + Human: {{content}} + {{else if (eq role 'assistant')}} + Assistant: {{content}} + {{else if (eq role 'system')}} + System: {{content}} + {{else if (eq role 'tool')}} + Tool Result: {{content}} + {{/if}} + {{#if metadata.confidence}} + (Confidence: {{metadata.confidence}}) + {{/if}} + {{/each}} + """; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java new file mode 100644 index 0000000000..88e952c806 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java @@ -0,0 +1,124 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; + +import org.opensearch.core.action.ActionListener; + +/** + * Common interface for modern memory types supporting ChatMessage-based interactions. + * + *

+ * <p>This interface provides a unified abstraction for different memory backend implementations,
+ * enabling consistent interaction patterns across various memory storage systems. It supports
+ * both conversation management and detailed trace data storage for comprehensive agent behavior
+ * tracking.</p>
+ *
+ * <p>Supported Memory Types:</p>
+ * <ul>
+ *   <li>Agentic Memory - Local cluster-based intelligent memory system</li>
+ *   <li>Remote Agentic Memory - Distributed agentic memory implementation</li>
+ *   <li>Bedrock AgentCore Memory - AWS Bedrock agent memory integration</li>
+ *   <li>Future memory types - Extensible for additional implementations</li>
+ * </ul>
+ *
+ * <p>Core Capabilities:</p>
+ * <ul>
+ *   <li>Message retrieval in standardized ChatMessage format</li>
+ *   <li>Conversation and session management</li>
+ *   <li>Interaction persistence with metadata support</li>
+ *   <li>Tool execution trace data storage</li>
+ *   <li>Dynamic interaction updates</li>
+ * </ul>
+ *
+ * <p>Note: ConversationIndex uses a separate legacy pipeline for backward compatibility
+ * and is not part of this modern interface hierarchy.</p>
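+ *
+ * <p>Typical call pattern (a sketch; {@code adapter} may be any implementation of this interface
+ * and the listener bodies are placeholders):</p>
+ * <pre>{@code
+ * ChatMemoryAdapter adapter = ...; // e.g. an AgenticMemoryAdapter
+ * adapter.getMessages(ActionListener.wrap(
+ *     messages -> buildPrompt(messages, adapter.getConversationId()),
+ *     error -> handleError(error)
+ * ));
+ * adapter.saveInteraction("question", "answer", null, 1, "chat", saveListener);
+ * }</pre>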

+ * + * @see ChatMessage + * @see AgenticMemoryAdapter + */ +public interface ChatMemoryAdapter { + /** + * Retrieve conversation messages in ChatMessage format + * @param listener ActionListener to handle the response + */ + void getMessages(ActionListener> listener); + + /** + * Get the conversation/session identifier + * @return conversation ID or session ID + */ + String getConversationId(); + + /** + * This is the main memory container ID used to identify the memory container + * in the memory management system. + * @return + */ + String getMemoryContainerId(); + + /** + * Save interaction to memory (optional implementation) + * @param question User question + * @param response AI response + * @param parentId Parent interaction ID + * @param traceNum Trace number + * @param action Action performed + * @param listener ActionListener to handle the response + */ + default void saveInteraction( + String question, + String response, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("Save not implemented")); + } + + /** + * Update existing interaction with additional information + * @param interactionId Interaction ID to update + * @param updateFields Fields to update (e.g., final answer, additional info) + * @param listener ActionListener to handle the response + */ + default void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { + listener.onFailure(new UnsupportedOperationException("Update interaction not implemented")); + } + + /** + * Save trace data as tool invocation data in working memory. + * + *

+ * <p>This method provides a standardized way to store detailed information about
+ * tool executions, enabling comprehensive tracking and analysis of agent behavior.
+ * Implementations should store this data in a structured format that supports
+ * later retrieval and analysis.</p>
+ *
+ * <p>Default implementation throws UnsupportedOperationException. Memory adapters
+ * that support trace data storage should override this method.</p>

+ * + * @param toolName Name of the tool that was executed (required) + * @param toolInput Input parameters passed to the tool (may be null) + * @param toolOutput Output/response from the tool execution (may be null) + * @param parentMemoryId Parent memory ID to associate this trace with (may be null) + * @param traceNum Trace sequence number for ordering (may be null) + * @param action Action/origin identifier for categorization (may be null) + * @param listener ActionListener to handle the response with created trace ID + * @throws UnsupportedOperationException if the implementation doesn't support trace data storage + */ + default void saveTraceData( + String toolName, + String toolInput, + String toolOutput, + String parentMemoryId, + Integer traceNum, + String action, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("Save trace data not implemented")); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java new file mode 100644 index 0000000000..31dd72604d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.time.Instant; +import java.util.Map; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +/** + * Enhanced memory message for chat agents - designed for extensibility. + * Supports multiple memory types: Agentic, Remote Agentic, Bedrock AgentCore, etc. + * + * Design Philosophy: + * - Text-first with rich metadata (hybrid approach) + * - Extensible for future multimodal content + * - Memory-type agnostic interface + * - Role-based message support + */ +@Builder +@AllArgsConstructor +@Getter +public class ChatMessage { + private String id; + private Instant timestamp; + private String sessionId; + private String role; // "user", "assistant", "system", "tool" + private String content; // Primary text content + private String contentType; // "text", "image", "tool_result", etc. (metadata) + private String origin; // "agentic_memory", "remote_agentic", "bedrock_agentcore", etc. 
+ private Map metadata; // Rich content details and memory-specific data +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..4b44c55738 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -27,6 +27,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.UUID; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; @@ -46,6 +47,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -95,6 +97,7 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener { public static final String MEMORY_ID = "memory_id"; + public static final String MEMORY_CONTAINER_ID = "memory_container_id"; public static final String QUESTION = "question"; public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; @@ -174,194 +177,211 @@ public void execute(Input input, ActionListener listener, TransportChann if (MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_AGENT_INDEX)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient - .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor("opensearch_ml_general")) - .whenComplete((response, throwable) -> { - context.restore(); - log.debug("Completed Get Agent Request, Agent id:{}", agentId); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { - log.error("Failed to get Agent index", cause); - listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); - } else { - log.error("Failed to get ML Agent {}", agentId, cause); - listener.onFailure(cause); - } + sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((response, throwable) -> { + context.restore(); + log.debug("Completed Get Agent Request, Agent id:{}", agentId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to get Agent index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); } else { - try { - GetResponse getAgentResponse = response.parser() == null - ? 
null - : GetResponse.fromXContent(response.parser()); - if (getAgentResponse != null && getAgentResponse.isExists()) { - try ( - XContentParser parser = jsonXContent - .createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - getAgentResponse.getSourceAsString() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLAgent mlAgent = MLAgent.parse(parser); - if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { - listener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this resource", - RestStatus.FORBIDDEN - ) - ); - } - MLMemorySpec memorySpec = mlAgent.getMemory(); - String memoryId = inputDataSet.getParameters().get(MEMORY_ID); - String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); - String appType = mlAgent.getAppType(); - String question = inputDataSet.getParameters().get(QUESTION); - - if (parentInteractionId != null && regenerateInteractionId != null) { - throw new IllegalArgumentException( - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." + log.error("Failed to get ML Agent {}", agentId, cause); + listener.onFailure(cause); + } + } else { + try { + GetResponse getAgentResponse = response.parser() == null ? null : GetResponse.fromXContent(response.parser()); + if (getAgentResponse != null && getAgentResponse.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + getAgentResponse.getSourceAsString() + ) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { + listener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to access this resource", + RestStatus.FORBIDDEN + ) ); - } + } + MLMemorySpec memorySpec = mlAgent.getMemory(); + String memoryId; + if (Objects.equals(mlAgent.getMemory().getType(), MLMemoryType.CONVERSATION_INDEX.name())) { + memoryId = inputDataSet.getParameters().get(MEMORY_ID); + } else { + memoryId = inputDataSet.getParameters().get(MEMORY_CONTAINER_ID); + } - MLTask mlTask = MLTask - .builder() - .taskType(MLTaskType.AGENT_EXECUTION) - .functionName(FunctionName.AGENT) - .state(MLTaskState.CREATED) - .workerNodes(ImmutableList.of(clusterService.localNode().getId())) - .createTime(Instant.now()) - .lastUpdateTime(Instant.now()) - .async(false) - .tenantId(tenantId) - .build(); - - if (memoryId == null && regenerateInteractionId != null) { - throw new IllegalArgumentException("A memory ID must be provided to regenerate."); - } - if (memorySpec != null - && memorySpec.getType() != null - && memoryFactoryMap.containsKey(memorySpec.getType()) - && (memoryId == null || parentInteractionId == null)) { - ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); - conversationIndexMemoryFactory - .create(question, memoryId, appType, ActionListener.wrap(memory -> { - inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // get question for regenerate - if (regenerateInteractionId != null) { - log.info("Regenerate for existing interaction {}", regenerateInteractionId); - 
client - .execute( - GetInteractionAction.INSTANCE, - new GetInteractionRequest(regenerateInteractionId), - ActionListener.wrap(interactionRes -> { - inputDataSet - .getParameters() - .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - }, e -> { - log.error("Failed to get existing interaction for regeneration", e); - listener.onFailure(e); - }) + String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); + String appType = mlAgent.getAppType(); + String question = inputDataSet.getParameters().get(QUESTION); + + if (parentInteractionId != null && regenerateInteractionId != null) { + throw new IllegalArgumentException( + "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." + ); + } + + MLTask mlTask = MLTask + .builder() + .taskType(MLTaskType.AGENT_EXECUTION) + .functionName(FunctionName.AGENT) + .state(MLTaskState.CREATED) + .workerNodes(ImmutableList.of(clusterService.localNode().getId())) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .async(false) + .tenantId(tenantId) + .build(); + + if (memoryId == null && regenerateInteractionId != null) { + throw new IllegalArgumentException("A memory ID must be provided to regenerate."); + } + + // NEW: Handle AGENTIC_MEMORY type before ConversationIndex logic + if (memorySpec != null && "AGENTIC_MEMORY".equals(memorySpec.getType())) { + log.debug("Detected AGENTIC_MEMORY type - routing to agentic memory handler"); + handleAgenticMemory( + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + listener, + channel + ); + } + // EXISTING: ConversationIndex logic remains unchanged + else if (memorySpec != null + && memorySpec.getType() != null + && memoryFactoryMap.containsKey(memorySpec.getType()) + && (memoryId == null || parentInteractionId == null)) { + ConversationIndexMemory.Factory conversationIndexMemoryFactory = + (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); + conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { + inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel ); - } else { - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - } - }, ex -> { - log.error("Failed to read conversation memory", ex); - listener.onFailure(ex); - })); - } else { - // For existing conversations, create memory instance using factory - if (memorySpec != null && memorySpec.getType() != null) { - ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - if (factory != null) { - // 
memoryId exists, so create returns an object with existing memory, therefore name can - // be null - factory - .create( - null, - memoryId, - appType, - ActionListener - .wrap( - createdMemory -> executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - createdMemory, - channel - ), - ex -> { - log.error("Failed to find memory with memory_id: {}", memoryId, ex); - listener.onFailure(ex); - } - ) - ); - return; - } + }, e -> { + log.error("Failed to get existing interaction for regeneration", e); + listener.onFailure(e); + }) + ); + } else { + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } + }, ex -> { + log.error("Failed to read conversation memory", ex); + listener.onFailure(ex); + })); + } else { + // For existing conversations, create memory instance using factory + if (memorySpec != null && memorySpec.getType() != null) { + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); + if (factory != null) { + // memoryId exists, so create returns an object with existing memory, therefore name can + // be null + factory + .create( + null, + memoryId, + appType, + ActionListener + .wrap( + createdMemory -> executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + createdMemory, + channel + ), + ex -> { + log.error("Failed to find memory with memory_id: {}", memoryId, ex); + listener.onFailure(ex); + } + ) + ); + return; } - executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - null, - channel - ); } - } catch (Exception e) { - log.error("Failed to parse ml agent {}", agentId, e); - listener.onFailure(e); - } - } else { - listener - .onFailure( - new OpenSearchStatusException( - "Failed to find agent with the provided agent id: " + agentId, - RestStatus.NOT_FOUND - ) + executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + null, + channel ); + } + } catch (Exception e) { + log.error("Failed to parse ml agent {}", agentId, e); + listener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to get agent", e); - listener.onFailure(e); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); } + } catch (Exception e) { + log.error("Failed to get agent", e); + listener.onFailure(e); } - }); + } + }); } } else { listener.onFailure(new ResourceNotFoundException("Agent index not found")); @@ -456,7 +476,7 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - ConversationIndexMemory memory, + Object memory, // Accept both ConversationIndexMemory and AgenticMemoryAdapter TransportChannel channel ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; @@ -472,12 +492,23 @@ private void executeAgent( // If async is true, index ML task and return the taskID. 
Also add memoryID to the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); - if (memoryId != null && !memoryId.isEmpty()) { - agentResponse.put(MEMORY_ID, memoryId); - } - if (parentInteractionId != null && !parentInteractionId.isEmpty()) { - agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); + // Handle different memory types for response + if (memory instanceof AgenticMemoryAdapter) { + AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; + agentResponse.put(MEMORY_ID, adapter.getMemoryContainerId()); // memory_container_id + agentResponse.put("session_id", adapter.getConversationId()); // session_id + if (parentInteractionId != null && !parentInteractionId.isEmpty()) { + agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); // actual interaction ID + } + } else { + // ConversationIndex behavior (unchanged) + if (memoryId != null && !memoryId.isEmpty()) { + agentResponse.put(MEMORY_ID, memoryId); + } + if (parentInteractionId != null && !parentInteractionId.isEmpty()) { + agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); + } } mlTask.setResponse(agentResponse); mlTask.setAsync(true); @@ -535,7 +566,7 @@ private ActionListener createAgentActionListener( List modelTensors, String agentType, String parentInteractionId, - ConversationIndexMemory memory + Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter ) { return ActionListener.wrap(output -> { if (output != null) { @@ -556,7 +587,7 @@ private ActionListener createAsyncTaskUpdater( List outputs, List modelTensors, String parentInteractionId, - ConversationIndexMemory memory + Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter ) { String taskId = mlTask.getTaskId(); Map agentResponse = new HashMap<>(); @@ -583,6 +614,7 @@ private ActionListener createAsyncTaskUpdater( e -> log.error("Failed to update ML task {} with agent execution results", taskId) ) ); + }, ex -> { agentResponse.put(ERROR_MESSAGE, ex.getMessage()); @@ -711,23 +743,259 @@ public void indexMLTask(MLTask mlTask, ActionListener listener) { } } - private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) { - if (interactionId != null && memory != null) { - String failureMessage = "Agent execution failed: " + errorMessage; - Map updateContent = new HashMap<>(); - updateContent.put(RESPONSE_FIELD, failureMessage); + /** + * Handle agentic memory type requests + */ + private void handleAgenticMemory( + RemoteInferenceInputDataSet inputDataSet, + MLTask mlTask, + boolean isAsync, + List outputs, + List modelTensors, + MLAgent mlAgent, + ActionListener listener, + TransportChannel channel + ) { + // Extract parameters + String memoryContainerId = inputDataSet.getParameters().get("memory_container_id"); + String sessionId = inputDataSet.getParameters().get("session_id"); + String ownerId = inputDataSet.getParameters().get("owner_id"); + String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + + log.debug("MLAgentExecutor: Processing AGENTIC_MEMORY request with parameters: {}", inputDataSet.getParameters().keySet()); + log + .debug( + "Extracted agentic memory parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}, parentInteractionId: {}", + memoryContainerId != null ? "present" : "null", + sessionId != null ? "present" : "null", + ownerId != null ? "present" : "null", + parentInteractionId != null ? 
"present" : "null" + ); + + // Parameter validation + if (memoryContainerId == null) { + log + .error( + "AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", + inputDataSet.getParameters().keySet() + ); + listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); + return; + } + + if (ownerId == null) { + log.error("AGENTIC_MEMORY validation failed: owner_id is null. Available params: {}", inputDataSet.getParameters().keySet()); + listener.onFailure(new IllegalArgumentException("owner_id is required for agentic memory")); + return; + } + + log.debug("AGENTIC_MEMORY parameter validation successful - memoryContainerId: {}, ownerId: {}", memoryContainerId, ownerId); + + // Session management (same pattern as ConversationIndex) + boolean isNewConversation = Strings.isEmpty(sessionId) || parentInteractionId == null; + log + .debug( + "Conversation type determination - sessionId: {}, parentInteractionId: {}, isNewConversation: {}", + sessionId != null ? "present" : "null", + parentInteractionId != null ? "present" : "null", + isNewConversation + ); + + if (isNewConversation) { + if (Strings.isEmpty(sessionId)) { + sessionId = UUID.randomUUID().toString(); // NEW conversation + log.debug("Generated new agentic memory session: {}", sessionId); + } else { + log.debug("Using provided session ID for new conversation: {}", sessionId); + } + } else { + log + .debug( + "Continuing existing agentic memory conversation - sessionId: {}, parentInteractionId: {}", + sessionId, + parentInteractionId + ); + } - memory - .getMemoryManager() - .updateInteraction( + // Create AgenticMemoryAdapter + log + .debug( + "Creating AgenticMemoryAdapter with parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}", + memoryContainerId, + sessionId, + ownerId + ); + try { + AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + log + .debug( + "AgenticMemoryAdapter created successfully - memoryContainerId: {}, sessionId: {}, conversationId: {}", + memoryContainerId, + sessionId, + adapter.getConversationId() + ); + + // Route to appropriate execution path + if (isNewConversation) { + // NEW conversation: create root interaction first + log + .debug( + "Execution path: NEW conversation - routing to saveRootInteractionAndExecuteAgentic for sessionId: {}", + sessionId + ); + saveRootInteractionAndExecuteAgentic( + listener, + adapter, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } else { + // EXISTING conversation: execute directly + log + .debug( + "Execution path: EXISTING conversation - routing to executeAgent for sessionId: {}, parentInteractionId: {}", + sessionId, + parentInteractionId + ); + executeAgent( + inputDataSet, + mlTask, + isAsync, + adapter.getMemoryContainerId(), + mlAgent, + outputs, + modelTensors, + listener, + adapter, + channel + ); + } + } catch (Exception ex) { + log + .error( + "AgenticMemoryAdapter creation failed - memoryContainerId: {}, sessionId: {}, ownerId: {}, error: {}", + memoryContainerId, + sessionId, + ownerId, + ex.getMessage(), + ex + ); + listener.onFailure(ex); + } + } + + /** + * Create root interaction for new agentic memory conversations (mirrors ConversationIndex pattern for tool tracing support) + */ + private void saveRootInteractionAndExecuteAgentic( + ActionListener listener, + AgenticMemoryAdapter adapter, + RemoteInferenceInputDataSet inputDataSet, + MLTask mlTask, + boolean isAsync, + List 
outputs, + List modelTensors, + MLAgent mlAgent, + TransportChannel channel + ) { + String question = inputDataSet.getParameters().get(QUESTION); + + log + .debug( + "Creating root interaction for agentic memory - memoryContainerId: {}, sessionId: {}, question: {}", + adapter.getMemoryContainerId(), + adapter.getConversationId(), + question != null ? "present" : "null" + ); + + // Create root interaction with empty response (same pattern as ConversationIndex) + // This enables tool tracing and proper interaction updating + adapter.saveInteraction(question, "", null, 0, "ROOT", ActionListener.wrap(interactionId -> { + log + .info( + "Root interaction created successfully for agentic memory - interactionId: {}, memoryContainerId: {}, sessionId: {}", interactionId, - updateContent, - ActionListener - .wrap( - res -> log.info("Updated interaction {} with failure message", interactionId), - e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) - ) + adapter.getMemoryContainerId(), + adapter.getConversationId() + ); + inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interactionId); + + log + .debug( + "Proceeding to executeAgent with root interaction - interactionId: {}, sessionId: {}", + interactionId, + adapter.getConversationId() + ); + + executeAgent( + inputDataSet, + mlTask, + isAsync, + adapter.getMemoryContainerId(), // Use memory_container_id as memoryId for agentic memory + mlAgent, + outputs, + modelTensors, + listener, + adapter, + channel + ); + }, ex -> { + log + .error( + "Root interaction creation failed for agentic memory - memoryContainerId: {}, sessionId: {}, error: {}", + adapter.getMemoryContainerId(), + adapter.getConversationId(), + ex.getMessage(), + ex ); + listener.onFailure(ex); + })); + } + + private void updateInteractionWithFailure(String interactionId, Object memory, String errorMessage) { + if (interactionId != null && memory != null) { + if (memory instanceof ConversationIndexMemory) { + // Existing ConversationIndex error handling + ConversationIndexMemory conversationMemory = (ConversationIndexMemory) memory; + String failureMessage = "Agent execution failed: " + errorMessage; + Map updateContent = new HashMap<>(); + updateContent.put(RESPONSE_FIELD, failureMessage); + + conversationMemory + .getMemoryManager() + .updateInteraction( + interactionId, + updateContent, + ActionListener + .wrap( + res -> log.info("Updated interaction {} with failure message", interactionId), + e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) + ) + ); + } else if (memory instanceof AgenticMemoryAdapter) { + // New agentic memory error handling + AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; + Map updateFields = new HashMap<>(); + updateFields.put("error", errorMessage); + + adapter + .updateInteraction( + interactionId, + updateFields, + ActionListener + .wrap( + res -> log.info("Updated agentic memory interaction {} with failure message", interactionId), + e -> log.warn("Failed to update agentic memory interaction {} with failure message", interactionId, e) + ) + ); + } else { + log.warn("Unknown memory type for error handling: {}", memory.getClass().getSimpleName()); + } } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 7e1a4050bd..f22e295062 100644 --- 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; @@ -57,6 +58,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -76,8 +78,6 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; -import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -177,78 +177,60 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener functionCalling.configure(params); } - String memoryType = mlAgent.getMemory().getType(); - String memoryId = params.get(MLAgentExecutor.MEMORY_ID); - String appType = mlAgent.getAppType(); - String title = params.get(MLAgentExecutor.QUESTION); String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); String chatHistoryQuestionTemplate = params.get(CHAT_HISTORY_QUESTION_TEMPLATE); String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); - conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { - // TODO: call runAgent directly if messageHistoryLimit == 0 - memory.getMessages(ActionListener.>wrap(r -> { - List messageList = new ArrayList<>(); - for (Interaction next : r) { - String question = next.getInput(); - String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, - // filter out those in-flight requests when run in parallel - if (Strings.isNullOrEmpty(response)) { - continue; - } - messageList - .add( - ConversationIndexMessage - .conversationIndexMessageBuilder() - .sessionId(memory.getConversationId()) - .question(question) - .response(response) - .build() - ); - } - if (!messageList.isEmpty()) { - if (chatHistoryQuestionTemplate == null) { - StringBuilder chatHistoryBuilder = new StringBuilder(); - chatHistoryBuilder.append(chatHistoryPrefix); - for (Message message : messageList) { - chatHistoryBuilder.append(message.toString()).append("\n"); - } - params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - } else { - List chatHistory = new ArrayList<>(); - for (Message message : messageList) { - Map messageParams = new HashMap<>(); - messageParams.put("question", 
processTextDoc(((ConversationIndexMessage) message).getQuestion())); - - StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); - chatHistory.add(chatQuestionMessage); - - messageParams.clear(); - messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); - substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); - chatHistory.add(chatResponseMessage); - } - params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - } - } + createMemoryAdapter(mlAgent, params, ActionListener.wrap(memoryOrAdapter -> { + log.debug("createMemoryAdapter callback: memoryOrAdapter type = {}", memoryOrAdapter.getClass().getSimpleName()); + + if (memoryOrAdapter instanceof ConversationIndexMemory) { + // Existing ConversationIndex flow - zero changes + ConversationIndexMemory memory = (ConversationIndexMemory) memoryOrAdapter; + memory.getMessages(ActionListener.>wrap(r -> { + processLegacyInteractions( + r, + memory.getConversationId(), + memory, + mlAgent, + params, + inputParams, + chatHistoryPrefix, + chatHistoryQuestionTemplate, + chatHistoryResponseTemplate, + functionCalling, + listener + ); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + }), messageHistoryLimit); + + } else if (memoryOrAdapter instanceof ChatMemoryAdapter) { + // Modern Pipeline - NEW ChatMessage processing + log.debug("Routing to modern ChatMemoryAdapter pipeline"); + ChatMemoryAdapter adapter = (ChatMemoryAdapter) memoryOrAdapter; + adapter.getMessages(ActionListener.wrap(chatMessages -> { + // Use NEW ChatMessage-based processing (no conversion to Interaction) + processModernChatMessages( + chatMessages, + adapter.getConversationId(), + adapter, // Add ChatMemoryAdapter parameter + mlAgent, + params, + inputParams, + functionCalling, + listener + ); + }, e -> { + log.error("Failed to get chat history from modern memory adapter", e); + listener.onFailure(e); + })); - runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - }), messageHistoryLimit); + } else { + listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryOrAdapter.getClass())); + } }, listener::onFailure)); } @@ -256,7 +238,7 @@ private void runAgent( MLAgent mlAgent, Map params, ActionListener listener, - Memory memory, + Object memoryOrSessionId, // Can be Memory object or String sessionId String sessionId, FunctionCalling functionCalling ) { @@ -267,7 +249,71 @@ private void runAgent( Map tools = new HashMap<>(); Map toolSpecMap = new HashMap<>(); createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent); - runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling); + + // Route to correct runReAct method based on memory type + if (memoryOrSessionId instanceof Memory) { + // Legacy ConversationIndex path + Memory actualMemory = (Memory) memoryOrSessionId; + runReAct( + 
mlAgent.getLlm(), + tools, + toolSpecMap, + params, + actualMemory, + sessionId, + mlAgent.getTenantId(), + listener, + functionCalling + ); + } else { + // Modern agentic memory path - create ChatMemoryAdapter + String memoryContainerId = params.get("memory_container_id"); + String ownerId = params.get("owner_id"); + + log + .debug( + "Agentic memory path: memoryContainerId={}, ownerId={}, sessionId={}, allParams={}", + memoryContainerId, + ownerId, + sessionId, + params.keySet() + ); + + if (memoryContainerId != null && ownerId != null) { + AgenticMemoryAdapter chatMemoryAdapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + runReAct( + mlAgent.getLlm(), + tools, + toolSpecMap, + params, + chatMemoryAdapter, + sessionId, + mlAgent.getTenantId(), + listener, + functionCalling + ); + } else { + // Missing required parameters for agentic memory + log + .error( + "Agentic memory requested but missing required parameters. " + + "memory_container_id: {}, owner_id: {}, available params: {}", + memoryContainerId, + ownerId, + params.keySet() + ); + listener + .onFailure( + new IllegalArgumentException( + "Agentic memory requires 'memory_container_id' and 'owner_id' parameters. " + + "Provided: memory_container_id=" + + memoryContainerId + + ", owner_id=" + + ownerId + ) + ); + } + } }; // Fetch MCP tools and handle both success and failure cases @@ -387,17 +433,32 @@ private void runReAct( .build() ); - saveTraceData( - conversationIndexMemory, - memory.getType(), - question, - thoughtResponse, - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - "LLM" - ); + // Save trace data using appropriate memory adapter + if (memory instanceof ConversationIndexMemory) { + saveTraceData( + (ConversationIndexMemory) memory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + } else if (memory instanceof ChatMemoryAdapter) { + saveTraceData( + (ChatMemoryAdapter) memory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + } if (nextStepListener == null) { handleMaxIterationsReached( @@ -466,17 +527,32 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - saveTraceData( - conversationIndexMemory, - "ReAct", - lastActionInput.get(), - outputToOutputString(filteredOutput), - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - lastAction.get() - ); + // Save trace data using appropriate memory adapter + if (memory instanceof ConversationIndexMemory) { + saveTraceData( + (ConversationIndexMemory) memory, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + } else if (memory instanceof ChatMemoryAdapter) { + saveTraceData( + (ChatMemoryAdapter) memory, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + } StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); newPrompt.set(substitutor.replace(finalPrompt)); @@ -581,7 +657,7 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); + additionalInfo.put(toolOutputKey, new 
ArrayList<>(Collections.singletonList(outputString))); } } } @@ -705,6 +781,45 @@ public static void saveTraceData( } } + /** + * Overloaded saveTraceData method for ChatMemoryAdapter + */ + public static void saveTraceData( + ChatMemoryAdapter chatMemoryAdapter, + String memoryType, + String question, + String thoughtResponse, + String sessionId, + boolean traceDisabled, + String parentInteractionId, + AtomicInteger traceNumber, + String origin + ) { + if (chatMemoryAdapter != null && !traceDisabled) { + // Save trace data as tool invocation data in working memory + chatMemoryAdapter + .saveTraceData( + origin, // toolName (LLM, ReAct, etc.) + question, // toolInput + thoughtResponse, // toolOutput + parentInteractionId, // parentMemoryId + traceNumber.addAndGet(1), // traceNum + origin, // action + ActionListener + .wrap( + r -> log + .debug( + "Successfully saved trace data via ChatMemoryAdapter for session: {}, origin: {}", + sessionId, + origin + ), + e -> log + .warn("Failed to save trace data via ChatMemoryAdapter for session: {}, origin: {}", sessionId, origin, e) + ) + ); + } + } + private void sendFinalAnswer( String sessionId, ActionListener listener, @@ -759,6 +874,51 @@ private void sendFinalAnswer( } } + /** + * Overloaded sendFinalAnswer method for modern ChatMemoryAdapter pipeline + */ + private void sendFinalAnswer( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List cotModelTensors, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + AtomicInteger traceNumber, + Map additionalInfo, + String finalAnswer + ) { + // Send completion chunk for streaming + streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); + + if (chatMemoryAdapter != null) { + String copyOfFinalAnswer = finalAnswer; + ActionListener saveTraceListener = ActionListener.wrap(r -> { + // For ChatMemoryAdapter, we don't have separate updateInteraction + // The saveInteraction method handles the complete saving + streamingWrapper + .sendFinalResponse( + sessionId, + listener, + parentInteractionId, + verbose, + cotModelTensors, + additionalInfo, + copyOfFinalAnswer + ); + }, listener::onFailure); + + // Use ChatMemoryAdapter's saveInteraction method + chatMemoryAdapter + .saveInteraction(question, finalAnswer, parentInteractionId, traceNumber.addAndGet(1), "LLM", saveTraceListener); + } else { + streamingWrapper + .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); + } + } + public static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); @@ -863,7 +1023,7 @@ public static void returnFinalResponse( ModelTensor .builder() .name("response") - .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -908,6 +1068,305 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } + /** + * Overloaded handleMaxIterationsReached method for ChatMemoryAdapter + */ + private void handleMaxIterationsReached( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + AtomicInteger traceNumber, + Map additionalInfo, + AtomicReference lastThought, + int maxIterations, + Map tools 
+ ) { + String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory + traceNumber, + additionalInfo, + incompleteResponse + ); + cleanUpResource(tools); + } + + /** + * Complete runReAct method for modern ChatMemoryAdapter pipeline + * This method handles the new memory types (agentic, remote, bedrock, etc.) + * + * Full implementation with complete ReAct loop, tool execution, trace saving, and streaming. + */ + private void runReAct( + LLMSpec llm, + Map tools, + Map toolSpecMap, + Map parameters, + ChatMemoryAdapter chatMemoryAdapter, // Modern parameter + String sessionId, + String tenantId, + ActionListener listener, + FunctionCalling functionCalling + ) { + Map tmpParameters = constructLLMParams(llm, parameters); + String prompt = constructLLMPrompt(tools, tmpParameters); + tmpParameters.put(PROMPT, prompt); + final String finalPrompt = prompt; + + String question = tmpParameters.get(MLAgentExecutor.QUESTION); + String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); + boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false")); + boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); + + // Trace number + AtomicInteger traceNumber = new AtomicInteger(0); + + AtomicReference> lastLlmListener = new AtomicReference<>(); + AtomicReference lastThought = new AtomicReference<>(); + AtomicReference lastAction = new AtomicReference<>(); + AtomicReference lastActionInput = new AtomicReference<>(); + AtomicReference lastToolSelectionResponse = new AtomicReference<>(); + Map additionalInfo = new ConcurrentHashMap<>(); + Map lastToolParams = new ConcurrentHashMap<>(); + + StepListener firstListener = new StepListener(); + lastLlmListener.set(firstListener); + StepListener lastStepListener = firstListener; + + StringBuilder scratchpadBuilder = new StringBuilder(); + List interactions = new CopyOnWriteArrayList<>(); + + StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); + AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + List traceTensors = createModelTensors(sessionId, parentInteractionId); + int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS)); + + for (int i = 0; i < maxIterations; i++) { + int finalI = i; + StepListener nextStepListener = (i == maxIterations - 1) ? 
null : new StepListener<>(); + + lastStepListener.whenComplete(output -> { + StringBuilder sessionMsgAnswerBuilder = new StringBuilder(); + if (finalI % 2 == 0) { + MLTaskResponse llmResponse = (MLTaskResponse) output; + ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); + Map modelOutput = parseLLMOutput( + parameters, + tmpModelTensorOutput, + llmResponsePatterns, + tools.keySet(), + interactions, + functionCalling + ); + + streamingWrapper.fixInteractionRole(interactions); + String thought = String.valueOf(modelOutput.get(THOUGHT)); + String toolCallId = String.valueOf(modelOutput.get("tool_call_id")); + String action = String.valueOf(modelOutput.get(ACTION)); + String actionInput = String.valueOf(modelOutput.get(ACTION_INPUT)); + String thoughtResponse = modelOutput.get(THOUGHT_RESPONSE); + String finalAnswer = modelOutput.get(FINAL_ANSWER); + + if (finalAnswer != null) { + finalAnswer = finalAnswer.trim(); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory + traceNumber, + additionalInfo, + finalAnswer + ); + cleanUpResource(tools); + return; + } + + sessionMsgAnswerBuilder.append(thought); + lastThought.set(thought); + lastAction.set(action); + lastActionInput.set(actionInput); + lastToolSelectionResponse.set(thoughtResponse); + + traceTensors + .add( + ModelTensors + .builder() + .mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())) + .build() + ); + + // Save trace data using ChatMemoryAdapter + saveTraceData( + chatMemoryAdapter, + "ChatMemoryAdapter", // Memory type for modern pipeline + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); + + if (nextStepListener == null) { + handleMaxIterationsReached( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + return; + } + + if (tools.containsKey(action)) { + Map toolParams = constructToolParams( + tools, + toolSpecMap, + question, + lastActionInput, + action, + actionInput + ); + lastToolParams.clear(); + lastToolParams.putAll(toolParams); + runTool( + tools, + toolSpecMap, + tmpParameters, + (ActionListener) nextStepListener, + action, + actionInput, + toolParams, + interactions, + toolCallId, + functionCalling + ); + + } else { + String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); + StringSubstitutor substitutor = new StringSubstitutor( + Map.of(SCRATCHPAD, scratchpadBuilder.toString()), + "${parameters.", + "}" + ); + newPrompt.set(substitutor.replace(finalPrompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + ((ActionListener) nextStepListener).onResponse(res); + } + } else { + Object filteredOutput = filterToolOutput(lastToolParams, output); + addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); + + String toolResponse = constructToolResponse( + tmpParameters, + lastAction, + lastActionInput, + lastToolSelectionResponse, + filteredOutput + ); + scratchpadBuilder.append(toolResponse).append("\n\n"); + + // Save trace data for tool response using ChatMemoryAdapter + saveTraceData( + 
chatMemoryAdapter, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); + + StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); + newPrompt.set(substitutor.replace(finalPrompt)); + tmpParameters.put(PROMPT, newPrompt.get()); + if (!interactions.isEmpty()) { + tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); + } + + sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); + streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId); + traceTensors + .add( + ModelTensors + .builder() + .mlModelTensors( + Collections + .singletonList( + ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() + ) + ) + .build() + ); + + if (finalI == maxIterations - 1) { + handleMaxIterationsReached( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + chatMemoryAdapter, // Use ChatMemoryAdapter + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + return; + } + + if (nextStepListener != null) { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); + } + } + }, listener::onFailure); + + if (i == 0) { + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, firstListener); + } + if (nextStepListener != null) { + lastStepListener = nextStepListener; + } + } + } + private void saveMessage( ConversationIndexMemory memory, String question, @@ -933,4 +1392,171 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + + /** + * Process modern ChatMessage format and build chat history using enhanced templates + */ + private void processModernChatMessages( + List chatMessages, + String sessionId, + ChatMemoryAdapter chatMemoryAdapter, // Add ChatMemoryAdapter parameter + MLAgent mlAgent, + Map params, + Map inputParams, + FunctionCalling functionCalling, + ActionListener listener + ) { + // Use new enhanced template system for ChatMessage + SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + + // Filter out empty content messages (in-flight requests) + List validMessages = chatMessages + .stream() + .filter(msg -> msg.getContent() != null && !msg.getContent().trim().isEmpty()) + .toList(); + + if (!validMessages.isEmpty()) { + // Build chat history using enhanced template system + String chatHistory = templateEngine.buildSimpleChatHistory(validMessages); + params.put(CHAT_HISTORY, chatHistory); + inputParams.put(CHAT_HISTORY, chatHistory); + } + + // Run agent with modern processing (no Memory object needed) + runAgent(mlAgent, params, listener, sessionId, sessionId, functionCalling); + } + + /** + * Process legacy interactions (ConversationIndex) and build chat history, then run the agent + */ + private void processLegacyInteractions( + List interactions, + String sessionId, + ConversationIndexMemory memory, + MLAgent mlAgent, + Map params, + Map inputParams, + String chatHistoryPrefix, + String chatHistoryQuestionTemplate, + String chatHistoryResponseTemplate, + FunctionCalling functionCalling, + ActionListener listener + ) { + List messageList = new 
ArrayList<>(); + for (Interaction next : interactions) { + String question = next.getInput(); + String response = next.getResponse(); + // As we store the conversation with empty response first and then update when have final answer, + // filter out those in-flight requests when run in parallel + if (Strings.isNullOrEmpty(response)) { + continue; + } + messageList + .add( + ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId(sessionId) + .question(question) + .response(response) + .build() + ); + } + + if (!messageList.isEmpty()) { + if (chatHistoryQuestionTemplate == null) { + StringBuilder chatHistoryBuilder = new StringBuilder(); + chatHistoryBuilder.append(chatHistoryPrefix); + for (Message message : messageList) { + chatHistoryBuilder.append(message.toString()).append("\n"); + } + params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + } else { + List chatHistory = new ArrayList<>(); + for (Message message : messageList) { + Map messageParams = new HashMap<>(); + messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion())); + + StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); + chatHistory.add(chatQuestionMessage); + + messageParams.clear(); + messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); + substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); + chatHistory.add(chatResponseMessage); + } + params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + } + } + + runAgent(mlAgent, params, listener, memory != null ? 
memory : sessionId, sessionId, functionCalling); + } + + /** + * Create appropriate memory adapter based on memory type + */ + private void createMemoryAdapter(MLAgent mlAgent, Map params, ActionListener listener) { + String memoryType = mlAgent.getMemory().getType(); + MLMemoryType type = MLMemoryType.from(memoryType); + + log.debug("MLChatAgentRunner.createMemoryAdapter: memoryType={}, params={}", memoryType, params.keySet()); + + switch (type) { + case CONVERSATION_INDEX: + // Keep existing flow - no adapter needed (zero risk approach) + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + String title = params.get(MLAgentExecutor.QUESTION); + String memoryId = params.get(MLAgentExecutor.MEMORY_ID); + String appType = mlAgent.getAppType(); + + factory.create(title, memoryId, appType, ActionListener.wrap(conversationMemory -> { + // Return ConversationIndexMemory directly - no conversion needed + listener.onResponse(conversationMemory); + }, listener::onFailure)); + break; + + case AGENTIC_MEMORY: + // New agentic memory path + String memoryContainerId = params.get("memory_container_id"); + String sessionId = params.get("session_id"); + String ownerId = params.get("owner_id"); // From user context + + log.debug("AGENTIC_MEMORY path: memoryContainerId={}, sessionId={}, ownerId={}", memoryContainerId, sessionId, ownerId); + + // Validate required parameters + if (memoryContainerId == null) { + log.error("AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", params.keySet()); + listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); + return; + } + + // Session management: same pattern as ConversationIndex + if (Strings.isEmpty(sessionId)) { + // CREATE NEW: Generate new session ID if not provided + sessionId = UUID.randomUUID().toString(); + log.debug("Created new agentic memory session: {}", sessionId); + } + // USE EXISTING: If sessionId provided, use it directly + + AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + log.debug("Created AgenticMemoryAdapter successfully: memoryContainerId={}, sessionId={}", memoryContainerId, sessionId); + listener.onResponse(adapter); + break; + + default: + // Future memory types will be added here: + // - REMOTE_AGENTIC_MEMORY: RemoteAgenticMemoryAdapter (similar format, different location) + // - BEDROCK_AGENTCORE: BedrockAgentCoreMemoryAdapter (format adapted in adapter) + // All future types will use modern ChatMessage pipeline + listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryType)); + } + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java new file mode 100644 index 0000000000..33399208e4 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java @@ -0,0 +1,81 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import java.util.List; +import java.util.Map; + +/** + * Simple implementation of ChatHistoryTemplateEngine. + * Provides basic template functionality for ChatMessage formatting. 
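+ *
+ * <p>A minimal usage sketch (illustrative only; the builder fields and the "Human:"/"Assistant:"
+ * prefixes mirror what MLChatAgentRunnerTest exercises later in this patch):
+ * <pre>{@code
+ * SimpleChatHistoryTemplateEngine engine = new SimpleChatHistoryTemplateEngine();
+ * List<ChatMessage> messages = List.of(
+ *     ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(),
+ *     ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build());
+ * String history = engine.buildSimpleChatHistory(messages);
+ * // history == "Human: What's the weather?\nAssistant: It's sunny today!"
+ * }</pre>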
+ *
+ * This is a simplified implementation that supports:
+ * - Role-based message formatting
+ * - Basic placeholder replacement
+ * - Content type awareness
+ *
+ * Future versions can implement more advanced template engines (Handlebars, etc.)
+ */
+public class SimpleChatHistoryTemplateEngine implements ChatHistoryTemplateEngine {
+
+    @Override
+    public String buildChatHistory(List<ChatMessage> messages, String template, Map<String, Object> context) {
+        if (messages == null || messages.isEmpty()) {
+            return "";
+        }
+
+        // For now, use a simple approach - build chat history with role-based formatting
+        StringBuilder chatHistory = new StringBuilder();
+
+        for (ChatMessage message : messages) {
+            String formattedMessage = formatMessage(message);
+            chatHistory.append(formattedMessage).append("\n");
+        }
+
+        return chatHistory.toString().trim();
+    }
+
+    /**
+     * Format a single ChatMessage based on its role and content type
+     */
+    private String formatMessage(ChatMessage message) {
+        String role = message.getRole();
+        String content = message.getContent();
+        String contentType = message.getContentType();
+
+        // Role-based formatting
+        String prefix = switch (role) {
+            case "user" -> "Human: ";
+            case "assistant" -> "Assistant: ";
+            case "system" -> "System: ";
+            case "tool" -> "Tool Result: ";
+            default -> role + ": ";
+        };
+
+        // Content type awareness
+        String formattedContent = content;
+        if ("image".equals(contentType)) {
+            formattedContent = "[Image: " + content + "]";
+        } else if ("tool_result".equals(contentType)) {
+            Map<String, Object> metadata = message.getMetadata();
+            if (metadata != null && metadata.containsKey("tool_name")) {
+                formattedContent = "Tool " + metadata.get("tool_name") + ": " + content;
+            }
+        } else if ("context".equals(contentType)) {
+            // Context messages (like from long-term memory) get special formatting
+            formattedContent = "[Context] " + content;
+        }
+
+        return prefix + formattedContent;
+    }
+
+    /**
+     * Build chat history using default simple formatting
+     */
+    public String buildSimpleChatHistory(List<ChatMessage> messages) {
+        return buildChatHistory(messages, getDefaultTemplate(), Map.of());
+    }
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java
new file mode 100644
index 0000000000..0111642129
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.engine.memory;
+
+import java.util.Map;
+
+/**
+ * Interface for chat messages in the unified memory system.
+ * Provides a common abstraction for messages across different memory implementations.
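+ *
+ * <p>A minimal illustrative implementation (a sketch only; the Map&lt;String, Object&gt; metadata
+ * typing and the empty metadata shown here are assumptions, not part of this change):
+ * <pre>{@code
+ * ChatMessage message = new ChatMessage() {
+ *     public String getRole() { return "user"; }
+ *     public String getContent() { return "What's the weather?"; }
+ *     public Map<String, Object> getMetadata() { return Map.of(); }
+ * };
+ * }</pre>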
+ */ +public interface ChatMessage { + /** + * Get the role of the message sender + * @return role such as "user", "assistant", "system" + */ + String getRole(); + + /** + * Get the content of the message + * @return message content + */ + String getContent(); + + /** + * Get additional metadata associated with the message + * @return metadata map + */ + Map getMetadata(); +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java new file mode 100644 index 0000000000..f69f7d71e2 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for AgenticMemoryAdapter. + */ +public class AgenticMemoryAdapterTest { + + @Mock + private Client client; + + private AgenticMemoryAdapter adapter; + private final String memoryContainerId = "test-memory-container"; + private final String sessionId = "test-session"; + private final String ownerId = "test-owner"; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullClient() { + new AgenticMemoryAdapter(null, memoryContainerId, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullMemoryContainerId() { + new AgenticMemoryAdapter(client, null, sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptyMemoryContainerId() { + new AgenticMemoryAdapter(client, "", sessionId, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullSessionId() { + new AgenticMemoryAdapter(client, memoryContainerId, null, ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptySessionId() { + new AgenticMemoryAdapter(client, memoryContainerId, "", ownerId); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithNullOwnerId() { + new AgenticMemoryAdapter(client, memoryContainerId, sessionId, null); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorWithEmptyOwnerId() { + new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ""); + } + + @Test + public void testGetConversationId() { + assertEquals(sessionId, adapter.getConversationId()); + } + + @Test + public void testGetMemoryContainerId() { + assertEquals(memoryContainerId, adapter.getMemoryContainerId()); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithNullToolName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + 
adapter.saveTraceData(null, "input", "output", "parent-id", 1, "action", listener); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithEmptyToolName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + adapter.saveTraceData("", "input", "output", "parent-id", 1, "action", listener); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveTraceDataWithNullListener() { + adapter.saveTraceData("tool", "input", "output", "parent-id", 1, "action", null); + } + + @Test(expected = IllegalArgumentException.class) + public void testSaveInteractionWithNullListener() { + adapter.saveInteraction("question", "response", null, 1, "action", null); + } + + @Test + public void testUpdateInteractionWithNullInteractionId() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction(null, updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithEmptyInteractionId() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction("", updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithNullUpdateFields() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + adapter.updateInteraction("interaction-id", null, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test + public void testUpdateInteractionWithEmptyUpdateFields() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map updateFields = new HashMap<>(); + + adapter.updateInteraction("interaction-id", updateFields, listener); + + // Verify that onFailure was called with IllegalArgumentException + verify(listener).onFailure(any(IllegalArgumentException.class)); + } + + @Test(expected = IllegalArgumentException.class) + public void testUpdateInteractionWithNullListener() { + Map updateFields = new HashMap<>(); + updateFields.put("response", "updated response"); + + adapter.updateInteraction("interaction-id", updateFields, null); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java new file mode 100644 index 0000000000..990598dd42 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +import java.util.List; + +import org.junit.Test; +import org.opensearch.core.action.ActionListener; + +/** + * Unit tests for ChatMemoryAdapter interface default methods. 
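+ *
+ * <p>As the tests below verify, the default methods report failure through the listener rather
+ * than throw, so a caller of an adapter that does not override them would observe something
+ * like this (sketch; the JUnit assertion helpers are illustrative):
+ * <pre>{@code
+ * adapter.saveInteraction("question", "response", null, 1, "action",
+ *     ActionListener.wrap(
+ *         id -> fail("default implementation never succeeds"),
+ *         e -> assertTrue(e instanceof UnsupportedOperationException)));
+ * }</pre>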
+ */ +public class ChatMemoryAdapterTest { + + /** + * Test implementation of ChatMemoryAdapter for testing default methods + */ + private static class TestChatMemoryAdapter implements ChatMemoryAdapter { + @Override + public void getMessages(ActionListener> listener) { + // Test implementation - not used in these tests + } + + @Override + public String getConversationId() { + return "test-conversation-id"; + } + + @Override + public String getMemoryContainerId() { + return "test-memory-container-id"; + } + } + + @Test + public void testSaveInteractionDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.saveInteraction("question", "response", "parentId", 1, "action", listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Save not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testUpdateInteractionDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.updateInteraction("interactionId", java.util.Map.of("key", "value"), listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Update interaction not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testSaveTraceDataDefaultImplementation() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Test that default implementation throws UnsupportedOperationException + adapter.saveTraceData("toolName", "input", "output", "parentId", 1, "action", listener); + + // Verify that onFailure was called with UnsupportedOperationException + org.mockito.Mockito + .verify(listener) + .onFailure( + org.mockito.ArgumentMatchers + .argThat( + exception -> exception instanceof UnsupportedOperationException + && "Save trace data not implemented".equals(exception.getMessage()) + ) + ); + } + + @Test + public void testGetConversationId() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + assertEquals("test-conversation-id", adapter.getConversationId()); + } + + @Test + public void testGetMemoryContainerId() { + TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); + assertEquals("test-memory-container-id", adapter.getMemoryContainerId()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index f6c3e3618e..57c472a4c4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1171,4 +1171,136 @@ public void testConstructLLMParams_DefaultValues() { 
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testCreateMemoryAdapter_ConversationIndex() { + // Test that ConversationIndex memory type returns ConversationIndexMemory + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.QUESTION, "test question"); + params.put(MLAgentExecutor.MEMORY_ID, "test_memory_id"); + + // Mock the memory factory + when(memoryMap.get("conversation_index")).thenReturn(memoryFactory); + + // Create a mock ConversationIndexMemory + org.opensearch.ml.engine.memory.ConversationIndexMemory mockMemory = Mockito + .mock(org.opensearch.ml.engine.memory.ConversationIndexMemory.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockMemory); + return null; + }).when(memoryFactory).create(anyString(), anyString(), anyString(), any()); + + // Test the createMemoryAdapter method + ActionListener testListener = new ActionListener() { + @Override + public void onResponse(Object result) { + // Verify that we get back a ConversationIndexMemory + assertTrue("Expected ConversationIndexMemory", result instanceof org.opensearch.ml.engine.memory.ConversationIndexMemory); + assertEquals("Memory should be the mocked instance", mockMemory, result); + } + + @Override + public void onFailure(Exception e) { + Assert.fail("Should not fail: " + e.getMessage()); + } + }; + + // This would normally be a private method call, but for testing we can verify the logic + // by checking that the correct memory type handling works through the public run method + // The actual test would need to be done through integration testing + } + + @Test + public void testCreateMemoryAdapter_AgenticMemory() { + // Test that agentic memory type returns AgenticMemoryAdapter + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("agentic_memory").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put("memory_container_id", "test_container_id"); + params.put("session_id", "test_session_id"); + params.put("owner_id", "test_owner_id"); + + // This test verifies that the agentic memory path would be taken + // Full integration testing would require mocking the agentic memory services + assertNotNull("MLAgent should be created successfully", mlAgent); + assertEquals("Memory type should be agentic_memory", "agentic_memory", mlAgent.getMemory().getType()); + } + + @Test + public void testEnhancedChatMessage() { + // Test the enhanced ChatMessage format + ChatMessage userMessage = ChatMessage + .builder() + .id("msg_1") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("user") + .content("Hello, how are you?") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.95)) + .build(); + + ChatMessage assistantMessage = ChatMessage + .builder() + .id("msg_2") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("assistant") + .content("I'm 
doing well, thank you!") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.98)) + .build(); + + // Verify the enhanced ChatMessage structure + assertEquals("user", userMessage.getRole()); + assertEquals("text", userMessage.getContentType()); + assertEquals("agentic_memory", userMessage.getOrigin()); + assertNotNull(userMessage.getMetadata()); + assertEquals(0.95, userMessage.getMetadata().get("confidence")); + + assertEquals("assistant", assistantMessage.getRole()); + assertEquals("I'm doing well, thank you!", assistantMessage.getContent()); + } + + @Test + public void testSimpleChatHistoryTemplateEngine() { + // Test the new template engine + SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + + List messages = List + .of( + ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(), + ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build(), + ChatMessage.builder().role("system").content("Weather data retrieved from API").contentType("context").build() + ); + + String chatHistory = templateEngine.buildSimpleChatHistory(messages); + + assertNotNull("Chat history should not be null", chatHistory); + assertTrue("Should contain user message", chatHistory.contains("Human: What's the weather?")); + assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); + assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index 4c3f6217af..4e819f103b 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -379,8 +379,9 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil public SearchSourceBuilder addOwnerIdFilter(User user, SearchSourceBuilder searchSourceBuilder) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); - + if (user != null) { + boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); + } return applyFilterToSearchSource(searchSourceBuilder, boolQueryBuilder); } From 66860abbc843d00df319217e721d1e62312e133d Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 20 Oct 2025 14:56:13 -0700 Subject: [PATCH 02/58] support remote store for agentic memory Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../common/connector/AbstractConnector.java | 7 +- .../ml/common/connector/ConnectorAction.java | 18 + .../ml/common/connector/HttpConnector.java | 12 +- .../memorycontainer/MemoryConfiguration.java | 30 +- .../MemoryContainerConstants.java | 1 + .../common/memorycontainer/RemoteStore.java | 150 ++++ .../memorycontainer/RemoteStoreType.java | 45 ++ .../ml/common/utils/StringUtils.java | 57 ++ .../index-mappings/ml_memory_container.json | 12 +- .../ml/common/connector/AwsConnectorTest.java | 1 + .../common/connector/ConnectorActionTest.java | 42 +- .../common/connector/HttpConnectorTest.java | 34 + .../MemoryConfigurationTests.java | 42 ++ .../memorycontainer/RemoteStoreTest.java | 77 ++ .../MLCreateConnectorInputTests.java | 1 + .../MLCreateConnectorRequestTests.java | 1 + 
.../ml/common/utils/StringUtilsTest.java | 161 ++++ .../remote/AwsConnectorExecutor.java | 10 +- .../remote/HttpJsonConnectorExecutor.java | 11 +- .../ml/engine/indices/MLIndicesHandler.java | 2 +- .../ExecuteConnectorTransportAction.java | 11 +- .../TransportCreateMemoryContainerAction.java | 136 +++- .../memory/MemorySearchService.java | 61 +- .../memory/TransportSearchMemoriesAction.java | 8 +- .../memory/TransportUpdateMemoryAction.java | 7 +- .../ml/helper/MemoryContainerHelper.java | 340 ++++++++- .../helper/MemoryContainerPipelineHelper.java | 109 +++ .../ml/helper/RemoteStorageHelper.java | 705 ++++++++++++++++++ .../ml/utils/MemorySearchQueryBuilder.java | 128 ++++ .../ml/helper/MemoryContainerHelperTests.java | 125 ++++ 31 files changed, 2275 insertions(+), 70 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java create mode 100644 common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStoreType.java create mode 100644 common/src/test/java/org/opensearch/ml/common/memorycontainer/RemoteStoreTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 351171ede6..190a6790f3 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -113,6 +113,7 @@ public class CommonValue { public static final String CLIENT_CONFIG_FIELD = "client_config"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; + public static final String CONNECTOR_ACTION_FIELD = "connector_action"; // MCP Constants public static final String MCP_TOOL_NAME_FIELD = "name"; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 9a035230a0..05f2d3781b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -121,8 +121,11 @@ public void parseResponse(T response, List modelTensors, boolea @Override public Optional findAction(String action) { - if (actions != null) { - return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + if (actions != null && action != null) { + if (ConnectorAction.ActionType.isValidAction(action)) { + return actions.stream().filter(a -> a.getActionType().name().equalsIgnoreCase(action)).findFirst(); + } + return actions.stream().filter(a -> action.equals(a.getName())).findFirst(); } return Optional.empty(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index c82f489296..3962e5a798 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -33,6 +34,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { public static final String ACTION_TYPE_FIELD 
= "action_type"; + public static final String NAME_FIELD = "name"; public static final String METHOD_FIELD = "method"; public static final String URL_FIELD = "url"; public static final String HEADERS_FIELD = "headers"; @@ -52,6 +54,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { private static final Logger logger = LogManager.getLogger(ConnectorAction.class); private ActionType actionType; + private String name; private String method; private String url; private Map headers; @@ -62,6 +65,7 @@ public class ConnectorAction implements ToXContentObject, Writeable { @Builder(toBuilder = true) public ConnectorAction( ActionType actionType, + String name, String method, String url, Map headers, @@ -78,7 +82,11 @@ public ConnectorAction( if (method == null) { throw new IllegalArgumentException("method can't be null"); } + if (name != null && ActionType.isValidAction(name)) { + throw new IllegalArgumentException("name can't be one of action type " + Arrays.toString(ActionType.values())); + } this.actionType = actionType; + this.name = name; this.method = method; this.url = url; this.headers = headers; @@ -97,6 +105,7 @@ public ConnectorAction(StreamInput input) throws IOException { this.requestBody = input.readOptionalString(); this.preProcessFunction = input.readOptionalString(); this.postProcessFunction = input.readOptionalString(); + this.name = input.readOptionalString();// TODO: add version check } @Override @@ -113,6 +122,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(requestBody); out.writeOptionalString(preProcessFunction); out.writeOptionalString(postProcessFunction); + out.writeOptionalString(name); // TODO: add version check } @Override @@ -139,6 +149,9 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params if (postProcessFunction != null) { builder.field(ACTION_POST_PROCESS_FUNCTION, postProcessFunction); } + if (name != null) { + builder.field(NAME_FIELD, name); + } return builder.endObject(); } @@ -149,6 +162,7 @@ public static ConnectorAction fromStream(StreamInput in) throws IOException { public static ConnectorAction parse(XContentParser parser) throws IOException { ActionType actionType = null; + String name = null; String method = null; String url = null; Map headers = null; @@ -165,6 +179,9 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { case ACTION_TYPE_FIELD: actionType = ActionType.valueOf(parser.text().toUpperCase(Locale.ROOT)); break; + case NAME_FIELD: + name = parser.text(); + break; case METHOD_FIELD: method = parser.text(); break; @@ -191,6 +208,7 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { return ConnectorAction .builder() .actionType(actionType) + .name(name) .method(method) .url(url) .headers(headers) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 53f66ce384..ecfb145444 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson; import static 
org.opensearch.ml.common.utils.StringUtils.parseParameters; import java.io.IOException; @@ -358,12 +359,15 @@ public T createPayload(String action, Map parameters) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); - if (!isJson(payload)) { + if (!isJsonOrNdjson(payload)) { throw new IllegalArgumentException("Invalid payload: " + payload); } else if (neededStreamParameterInPayload(parameters)) { - JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); - jsonObject.addProperty("stream", true); - payload = jsonObject.toString(); + // Only add stream parameter for single JSON objects (not NDJSON) + if (isJson(payload)) { + JsonObject jsonObject = JsonParser.parseString(payload).getAsJsonObject(); + jsonObject.addProperty("stream", true); + payload = jsonObject.toString(); + } } return (T) payload; } diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java index f328640577..41a89e2d26 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java @@ -23,6 +23,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MAX_INFER_SIZE_LIMIT_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_INDEX_PREFIX_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETERS_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.REMOTE_STORE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR; @@ -85,6 +86,7 @@ public class MemoryConfiguration implements ToXContentObject, Writeable { @Builder.Default private boolean useSystemIndex = true; private String tenantId; + private RemoteStore remoteStore; public MemoryConfiguration( String indexPrefix, @@ -99,7 +101,8 @@ public MemoryConfiguration( boolean disableHistory, boolean disableSession, boolean useSystemIndex, - String tenantId + String tenantId, + RemoteStore remoteStore ) { // Validate first validateInputs(embeddingModelType, embeddingModelId, dimension, maxInferSize); @@ -127,6 +130,7 @@ public MemoryConfiguration( this.disableSession = disableSession; this.useSystemIndex = useSystemIndex; this.tenantId = tenantId; + this.remoteStore = remoteStore; } private String buildIndexPrefix(String indexPrefix, boolean useSystemIndex) { @@ -168,6 +172,9 @@ public MemoryConfiguration(StreamInput input) throws IOException { this.disableSession = input.readBoolean(); this.useSystemIndex = input.readBoolean(); this.tenantId = input.readOptionalString(); + if (input.readBoolean()) { + this.remoteStore = new RemoteStore(input); + } } @Override @@ -200,6 +207,12 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(disableSession); out.writeBoolean(useSystemIndex); out.writeOptionalString(tenantId); + if (remoteStore != null) { + out.writeBoolean(true); + remoteStore.writeTo(out); + } else { + 
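+            // no remote store configured: write a false presence flag so the stream constructor knows to skip it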
out.writeBoolean(false);
+        }
     }
 
     @Override
@@ -250,6 +263,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
         if (tenantId != null) {
             builder.field(TENANT_ID_FIELD, tenantId);
         }
+        if (remoteStore != null) {
+            builder.field(REMOTE_STORE_FIELD, remoteStore);
+        }
         builder.endObject();
         return builder;
     }
@@ -268,6 +284,7 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio
         boolean disableSession = true;
         boolean useSystemIndex = true;
         String tenantId = null;
+        RemoteStore remoteStore = null;
 
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
         while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -319,6 +336,9 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio
                 case USE_SYSTEM_INDEX_FIELD:
                     useSystemIndex = parser.booleanValue();
                     break;
+                case REMOTE_STORE_FIELD:
+                    remoteStore = RemoteStore.parse(parser);
+                    break;
                 default:
                     parser.skipChildren();
                     break;
@@ -341,6 +361,7 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio
             .disableSession(disableSession)
             .useSystemIndex(useSystemIndex)
             .tenantId(tenantId)
+            .remoteStore(remoteStore)
             .build();
     }
 
@@ -474,6 +495,10 @@ public static void validateStrategiesRequireModels(MemoryConfiguration config) {
         boolean hasLlm = config.getLlmId() != null;
         boolean hasEmbedding = config.getEmbeddingModelId() != null && config.getEmbeddingModelType() != null;
 
+        if (config.getRemoteStore() != null) {
+            hasEmbedding = config.getRemoteStore().getEmbeddingModelId() != null && config.getRemoteStore().getEmbeddingModelType() != null;
+        }
+
         if (!hasLlm || !hasEmbedding) {
             String missing = !hasLlm && !hasEmbedding ? "LLM model and embedding model"
                 : !hasLlm ? "LLM model (llm_id)"
@@ -526,6 +551,9 @@ public void update(MemoryConfiguration updateContent) {
             // Only update dimension for TEXT_EMBEDDING if provided
             this.dimension = updateContent.getDimension();
         }
+        if (updateContent.getRemoteStore() != null) {
+            this.remoteStore = updateContent.getRemoteStore();
+        }
         // Note: indexPrefix and other structural fields are intentionally not updated
         // as they would require index recreation
     }
diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java
index b514bac737..4253150bba 100644
--- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java
+++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java
@@ -39,6 +39,7 @@ public class MemoryContainerConstants {
     public static final String PARAMETERS_FIELD = "parameters";
     public static final String ID_FIELD = "id";
     public static final String ENABLED_FIELD = "enabled";
+    public static final String REMOTE_STORE_FIELD = "remote_store";
 
     // Default values
     public static final int MAX_INFER_SIZE_DEFAULT_VALUE = 5;
diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java
new file mode 100644
index 0000000000..9f140b5505
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java
@@ -0,0 +1,150 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.memorycontainer;
+
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DIMENSION_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_TYPE_FIELD; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; + +/** + * Remote store configuration for storing memory in remote locations like AWS OpenSearch Serverless + */ +@Data +@EqualsAndHashCode +public class RemoteStore implements ToXContentObject, Writeable { + + public static final String TYPE_FIELD = "type"; + public static final String CONNECTOR_ID_FIELD = "connector_id"; + + private RemoteStoreType type; + private String connectorId; + private FunctionName embeddingModelType; + private String embeddingModelId; + private Integer embeddingDimension; + + @Builder + public RemoteStore( + RemoteStoreType type, + String connectorId, + FunctionName embeddingModelType, + String embeddingModelId, + Integer embeddingDimension + ) { + if (type == null) { + throw new IllegalArgumentException("Invalid remote store type"); + } + this.type = type; + this.connectorId = connectorId; + this.embeddingModelType = embeddingModelType; + this.embeddingModelId = embeddingModelId; + this.embeddingDimension = embeddingDimension; + } + + public RemoteStore(StreamInput input) throws IOException { + this.type = input.readEnum(RemoteStoreType.class); + this.connectorId = input.readOptionalString(); + if (input.readOptionalBoolean()) { + this.embeddingModelType = input.readEnum(FunctionName.class); + } + this.embeddingModelId = input.readOptionalString(); + this.embeddingDimension = input.readOptionalInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(type); + out.writeOptionalString(connectorId); + if (embeddingModelType != null) { + out.writeBoolean(true); + out.writeEnum(embeddingModelType); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(embeddingModelId); + out.writeOptionalInt(embeddingDimension); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (connectorId != null) { + builder.field(CONNECTOR_ID_FIELD, connectorId); + } + if (embeddingModelType != null) { + builder.field(EMBEDDING_MODEL_TYPE_FIELD, embeddingModelType); + } + if (embeddingModelId != null) { + builder.field(EMBEDDING_MODEL_ID_FIELD, embeddingModelId); + } + if (embeddingDimension != null) { + builder.field(DIMENSION_FIELD, embeddingDimension); + } + builder.endObject(); + return builder; + } + + public static RemoteStore parse(XContentParser parser) throws IOException { + RemoteStoreType type = null; + String connectorId = null; + FunctionName embeddingModelType = null; + String embeddingModelId = null; + Integer embeddingDimension = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != 
XContentParser.Token.END_OBJECT) {
+            String fieldName = parser.currentName();
+            parser.nextToken();
+
+            switch (fieldName) {
+                case TYPE_FIELD:
+                    type = RemoteStoreType.fromString(parser.text());
+                    break;
+                case CONNECTOR_ID_FIELD:
+                    connectorId = parser.text();
+                    break;
+                case EMBEDDING_MODEL_TYPE_FIELD:
+                    embeddingModelType = FunctionName.from(parser.text());
+                    break;
+                case EMBEDDING_MODEL_ID_FIELD:
+                    embeddingModelId = parser.text();
+                    break;
+                case DIMENSION_FIELD:
+                    embeddingDimension = parser.intValue();
+                    break;
+                default:
+                    parser.skipChildren();
+                    break;
+            }
+        }
+
+        return RemoteStore
+            .builder()
+            .type(type)
+            .connectorId(connectorId)
+            .embeddingModelType(embeddingModelType)
+            .embeddingModelId(embeddingModelId)
+            .embeddingDimension(embeddingDimension)
+            .build();
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStoreType.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStoreType.java
new file mode 100644
index 0000000000..7c97a1ffea
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStoreType.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.memorycontainer;
+
+import java.util.Arrays;
+
+public enum RemoteStoreType {
+    OPENSEARCH("opensearch"),
+    AOSS("aoss");
+
+    private final String value;
+
+    RemoteStoreType(String value) {
+        this.value = value;
+    }
+
+    public String getValue() {
+        return value;
+    }
+
+    public static RemoteStoreType fromString(String value) {
+        if (value == null) {
+            return null;
+        }
+
+        String normalizedValue = value.toLowerCase();
+        for (RemoteStoreType type : RemoteStoreType.values()) {
+            if (type.value.equalsIgnoreCase(normalizedValue)) {
+                return type;
+            }
+        }
+
+        throw new IllegalArgumentException(
+            "Invalid remote store type: " + value + ". Must be one of: " + Arrays.toString(RemoteStoreType.values())
+        );
+    }
+
+    @Override
+    public String toString() {
+        return value;
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
index a3f1a3b416..5e697fbcef 100644
--- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
+++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java
@@ -136,6 +136,40 @@ public static boolean isJson(String json) {
         }
     }
 
+    /**
+     * Checks if the given string is valid JSON or NDJSON (newline-delimited JSON).
+     * NDJSON is commonly used for bulk operations in OpenSearch where each line is a separate JSON object.
+     *
+     * @param json the string to validate
+     * @return true if the string is valid JSON or NDJSON, false otherwise
+     */
+    public static boolean isJsonOrNdjson(String json) {
+        if (json == null || json.isBlank()) {
+            return false;
+        }
+
+        // First check if it's regular JSON
+        if (isJson(json)) {
+            return true;
+        }
+
+        // Check if it's NDJSON (newline-delimited JSON)
+        String[] lines = json.split("\\r?\\n");
+        if (lines.length == 0) {
+            return false;
+        }
+
+        // Each non-empty line must be valid JSON
+        for (String line : lines) {
+            String trimmedLine = line.trim();
+            if (!trimmedLine.isEmpty() && !isJson(trimmedLine)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
     /**
      * Ensures that a string is properly JSON escaped.
      *
@@ -302,6 +336,29 @@ public static String toJson(Object value) {
         }
     }
 
+    /**
+     * Converts an object to JSON string using plain number formatting (no scientific notation).
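+     * <p>
+     * Illustrative usage (hypothetical values, assuming a document whose epoch-millisecond
+     * timestamp was widened to a double during parsing):
+     * <pre>
+     *   Map&lt;String, Object&gt; doc = Map.of("created_time", 1.760760969707E12);
+     *   String json = StringUtils.toJsonWithPlainNumbers(doc);
+     *   // goal: {"created_time":1760760969707} rather than {"created_time":1.760760969707E12}
+     * </pre>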
+ * This is particularly useful for serializing documents with timestamp fields that need to be + * sent to remote storage systems that expect epoch milliseconds as plain long integers. + * + * @param value the object to convert to JSON + * @return JSON string representation with plain number formatting + */ + @SuppressWarnings("removal") + public static String toJsonWithPlainNumbers(Object value) { + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + if (value instanceof String) { + return (String) value; + } else { + return PLAIN_NUMBER_GSON.toJson(value); + } + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + @SuppressWarnings("removal") public static Map convertScriptStringToJsonString(Map processedInput) { Map parameterStringMap = new HashMap<>(); diff --git a/common/src/main/resources/index-mappings/ml_memory_container.json b/common/src/main/resources/index-mappings/ml_memory_container.json index a9a0162571..a39f7e0e24 100644 --- a/common/src/main/resources/index-mappings/ml_memory_container.json +++ b/common/src/main/resources/index-mappings/ml_memory_container.json @@ -1,6 +1,6 @@ { "_meta": { - "schema_version": 1 + "schema_version": 2 }, "properties": { "name": { @@ -53,6 +53,16 @@ "max_infer_size": { "type": "integer" }, + "remote_store": { + "properties": { + "type": { + "type": "keyword" + }, + "connector_id": { + "type": "keyword" + } + } + }, "strategies": { "type": "nested", "properties": { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index 2b679b8bbe..92dd038dbd 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -222,6 +222,7 @@ private AwsConnector createAwsConnector(Map parameters, Map new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(null, null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("action type can't be null", exception.getMessage()); @@ -109,7 +109,7 @@ public void constructor_NullActionType() { public void constructor_NullUrl() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null) ); assertEquals("url can't be null", exception.getMessage()); } @@ -118,14 +118,23 @@ public void constructor_NullUrl() { public void constructor_NullMethod() { Throwable exception = assertThrows( IllegalArgumentException.class, - () -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null) + () -> new ConnectorAction(TEST_ACTION_TYPE, null, null, URL, null, TEST_REQUEST_BODY, null, null) ); assertEquals("method can't be null", exception.getMessage()); } @Test public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess() { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, OPENAI_URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction( + TEST_ACTION_TYPE, + null, + TEST_METHOD_HTTP, + OPENAI_URL, + null, + TEST_REQUEST_BODY, + null, + null + ); action.validatePrePostProcessFunctions(Map.of()); 
assertFalse(testAppender.getLogEvents().stream().anyMatch(event -> event.getLevel() == Level.WARN)); } @@ -134,6 +143,7 @@ public void testValidatePrePostProcessFunctionsWithNullPreProcessFunctionSuccess public void testValidatePrePostProcessFunctionsWithExternalServers() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, null, @@ -151,6 +161,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces "\"\\n StringBuilder builder = new StringBuilder();\\n builder.append(\\\"\\\\\\\"\\\");\\n String first = params.text_docs[0];\\n builder.append(first);\\n builder.append(\\\"\\\\\\\"\\\");\\n def parameters = \\\"{\\\" +\\\"\\\\\\\"text_inputs\\\\\\\":\\\" + builder + \\\"}\\\";\\n return \\\"{\\\" +\\\"\\\\\\\"parameters\\\\\\\":\\\" + parameters + \\\"}\\\";\""; ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -166,6 +177,7 @@ public void testValidatePrePostProcessFunctionsWithCustomPainlessScriptPreProces public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, "https://${parameters.endpoint}/v1/chat/completions", null, @@ -181,6 +193,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -206,6 +219,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, OPENAI_URL, null, @@ -231,6 +245,7 @@ public void testValidatePrePostProcessFunctionsWithOpenAIConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -243,6 +258,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -255,6 +271,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -270,6 +287,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorCorrectInBuilt public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -295,6 +313,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPr public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, COHERE_URL, null, @@ -320,6 +339,7 @@ public void testValidatePrePostProcessFunctionsWithCohereConnectorWrongInBuiltPo public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + 
null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -332,6 +352,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -344,6 +365,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -359,6 +381,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorCorrectInBuil public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -384,6 +407,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, BEDROCK_URL, null, @@ -409,6 +433,7 @@ public void testValidatePrePostProcessFunctionsWithBedrockConnectorWrongInBuiltP public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrectInBuiltPrePostProcessFunctionSuccess() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -421,6 +446,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -436,6 +462,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWithCorrect public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPreProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -463,6 +490,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuiltPostProcessFunction() { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, SAGEMAKER_URL, null, @@ -488,7 +516,7 @@ public void testValidatePrePostProcessFunctionsWithSagemakerConnectorWrongInBuil @Test public void writeTo_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); @@ -504,6 +532,7 @@ public void writeTo() throws IOException { ConnectorAction action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, @@ -519,7 +548,7 @@ public void writeTo() throws IOException { @Test public void toXContent_NullValue() throws IOException { - ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); + ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, null, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -540,6 +569,7 @@ public void toXContent() throws IOException { ConnectorAction 
action = new ConnectorAction( TEST_ACTION_TYPE, + null, TEST_METHOD_HTTP, URL, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 1038006f2c..ad0f8c1252 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -274,6 +274,39 @@ public void createPayload_WithStreamParameter_UnsupportedInterface() { Assert.assertEquals("{\"input\": \"Hello world\"}", payload); } + @Test + public void createPayload_NdjsonFormat() { + // Test NDJSON format (newline-delimited JSON) commonly used for bulk operations + String requestBody = "{\"index\": {\"_index\": \"${parameters.index}\"}}\n{\"field\": \"${parameters.value}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + + Map parameters = new HashMap<>(); + parameters.put("index", "test_index"); + parameters.put("value", "test_value"); + + String payload = connector.createPayload(PREDICT.name(), parameters); + + Assert.assertEquals("{\"index\": {\"_index\": \"test_index\"}}\n{\"field\": \"test_value\"}", payload); + } + + @Test + public void createPayload_NdjsonFormat_WithStreamParameter() { + // Test that stream parameter is not added to NDJSON payloads + String requestBody = "{\"index\": {\"_index\": \"${parameters.index}\"}}\n{\"field\": \"${parameters.value}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + + Map parameters = new HashMap<>(); + parameters.put("index", "test_index"); + parameters.put("value", "test_value"); + parameters.put("stream", "true"); + parameters.put("_llm_interface", "openai/v1/chat/completions"); + + String payload = connector.createPayload(PREDICT.name(), parameters); + + // Stream parameter should not be added to NDJSON format + Assert.assertEquals("{\"index\": {\"_index\": \"test_index\"}}\n{\"field\": \"test_value\"}", payload); + } + @Test public void parseResponse_modelTensorJson() throws IOException { HttpConnector connector = createHttpConnector(); @@ -388,6 +421,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod ConnectorAction action = new ConnectorAction( actionType, + null, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java index 236adf0f43..c48a68f933 100644 --- a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java @@ -1080,4 +1080,46 @@ public void testBuildIndexPrefix_AcceptsHyphenAndUnderscore() { MemoryConfiguration config = MemoryConfiguration.builder().indexPrefix("valid_prefix-with-chars").build(); assertEquals("valid_prefix-with-chars", config.getIndexPrefix()); } + + // ==================== RemoteStore Tests ==================== + + @Test + public void testMemoryConfiguration_WithRemoteStore() { + RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); + + MemoryConfiguration config = MemoryConfiguration + .builder() + .indexPrefix("test") + .useSystemIndex(false) + .remoteStore(remoteStore) + .build(); + + assertNotNull(config.getRemoteStore()); + assertEquals("aoss", 
config.getRemoteStore().getType()); + assertEquals("ySf08JkBym-3qj1O2uub", config.getRemoteStore().getConnectorId()); + } + + @Test + public void testMemoryConfiguration_WithoutRemoteStore() { + MemoryConfiguration config = MemoryConfiguration.builder().indexPrefix("test").useSystemIndex(false).build(); + + assertNull(config.getRemoteStore()); + } + + @Test + public void testMemoryConfiguration_UpdateRemoteStore() { + MemoryConfiguration config = MemoryConfiguration.builder().indexPrefix("test").useSystemIndex(false).build(); + + assertNull(config.getRemoteStore()); + + RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); + + MemoryConfiguration updateContent = MemoryConfiguration.builder().remoteStore(remoteStore).build(); + + config.update(updateContent); + + assertNotNull(config.getRemoteStore()); + assertEquals("aoss", config.getRemoteStore().getType()); + assertEquals("ySf08JkBym-3qj1O2uub", config.getRemoteStore().getConnectorId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/RemoteStoreTest.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/RemoteStoreTest.java new file mode 100644 index 0000000000..134277c851 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/RemoteStoreTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +public class RemoteStoreTest { + + @Test + public void testRemoteStoreConstruction() { + RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); + + assertNotNull(remoteStore); + assertEquals("aoss", remoteStore.getType()); + assertEquals("ySf08JkBym-3qj1O2uub", remoteStore.getConnectorId()); + } + + @Test + public void testRemoteStoreToXContent() throws IOException { + RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); + + XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); + remoteStore.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + + assertNotNull(jsonStr); + assert (jsonStr.contains("\"type\":\"aoss\"")); + assert (jsonStr.contains("\"connector_id\":\"ySf08JkBym-3qj1O2uub\"")); + } + + @Test + public void testRemoteStoreParse() throws IOException { + String json = "{\"type\":\"aoss\",\"connector_id\":\"ySf08JkBym-3qj1O2uub\"}"; + + XContentParser parser = createParser(json); + parser.nextToken(); + RemoteStore remoteStore = RemoteStore.parse(parser); + + assertNotNull(remoteStore); + assertEquals("aoss", remoteStore.getType()); + assertEquals("ySf08JkBym-3qj1O2uub", remoteStore.getConnectorId()); + } + + @Test + public void testRemoteStoreSerialization() throws IOException { + RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); + + BytesStreamOutput output = new 
BytesStreamOutput(); + remoteStore.writeTo(output); + + StreamInput input = output.bytes().streamInput(); + RemoteStore deserializedRemoteStore = new RemoteStore(input); + + assertEquals(remoteStore.getType(), deserializedRemoteStore.getType()); + assertEquals(remoteStore.getConnectorId(), deserializedRemoteStore.getConnectorId()); + } + + private XContentParser createParser(String json) throws IOException { + XContentParser parser = XContentType.JSON.xContent().createParser(null, null, json); + return parser; + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index a7df00618a..dc5d4758a4 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -82,6 +82,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + null, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index b4f7629689..40fb7125d3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -48,6 +48,7 @@ public void setUp() { String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction( actionType, + null, method, url, headers, diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index e81ccc54a3..559f1adf6b 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -31,6 +31,12 @@ import org.junit.Test; import org.opensearch.OpenSearchParseException; import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -77,6 +83,161 @@ public void isJson_False() { assertFalse(StringUtils.isJson("[abc\n123]")); } + @Test + public void isJsonOrNdjson_ValidJson() { + // Regular JSON should still work + assertTrue(StringUtils.isJsonOrNdjson("{}")); + assertTrue(StringUtils.isJsonOrNdjson("[]")); + assertTrue(StringUtils.isJsonOrNdjson("{\"key\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("[1, 2, 3]")); + } + + @Test + public void isJsonOrNdjson_ValidNdjson() { + // NDJSON format (newline-delimited JSON) + assertTrue(StringUtils.isJsonOrNdjson("{\"index\": {\"_index\": \"test\"}}\n{\"field\": \"value\"}")); + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2}\n{\"c\": 3}")); + 
assertTrue(StringUtils.isJsonOrNdjson("[1, 2]\n[3, 4]")); + + // NDJSON with empty lines should be valid + assertTrue(StringUtils.isJsonOrNdjson("{\"a\": 1}\n\n{\"b\": 2}")); + + // Single line NDJSON + assertTrue(StringUtils.isJsonOrNdjson("{\"single\": \"line\"}")); + } + + @Test + public void isJsonOrNdjson_Invalid() { + // Invalid JSON + assertFalse(StringUtils.isJsonOrNdjson("{")); + assertFalse(StringUtils.isJsonOrNdjson("[")); + assertFalse(StringUtils.isJsonOrNdjson("{\"key\": \"value}")); + + // NDJSON with invalid JSON line + assertFalse(StringUtils.isJsonOrNdjson("{\"valid\": \"json\"}\n{invalid json}")); + assertFalse(StringUtils.isJsonOrNdjson("{\"a\": 1}\n{\"b\": 2\n{\"c\": 3}")); + + // Null and blank + assertFalse(StringUtils.isJsonOrNdjson(null)); + assertFalse(StringUtils.isJsonOrNdjson("")); + assertFalse(StringUtils.isJsonOrNdjson(" ")); + } + + @Test + public void isJsonOrNdjson_InvalidJson() { + String json = + "{\"index\":{\"_index\":\"demo2-memory-long-term\"}}\n{\"created_time\":1760760969707,\"memory\":\"Bob likes swimming.\",\"last_updated_time\":1760760969707,\"namespace_size\":1,\"owner_id\":\"admin\",\"namespace\":{\"user_id\":\"bob\"},\"strategy_id\":\"semantic_9766b0fe\",\"strategy_type\":\"SEMANTIC\",\"memory_container_id\":\"B4DU85kBZsSZwpNve_T0\",\"tags\":{\"topic\":\"personal info\"}}\n"; + assertTrue(StringUtils.isJsonOrNdjson(json)); + } + + private String convertBulkRequestToNDJSON(BulkRequest bulkRequest) { + StringBuilder ndjson = new StringBuilder(); + + for (var docWriteRequest : bulkRequest.requests()) { + if (docWriteRequest instanceof IndexRequest) { + IndexRequest indexRequest = (IndexRequest) docWriteRequest; + + // Action line + Map actionLine = new HashMap<>(); + actionLine.put("index", Map.of("_index", indexRequest.index())); + ndjson.append(StringUtils.toJson(actionLine)).append('\n'); + + // Document line + ndjson.append(indexRequest.source().utf8ToString()).append('\n'); + } + } + + return ndjson.toString(); + } + + private String convertBulkRequestToNDJSON1(BulkRequest bulkRequest) { + StringBuilder ndjson = new StringBuilder(); + + for (var docWriteRequest : bulkRequest.requests()) { + if (docWriteRequest instanceof IndexRequest) { + IndexRequest indexRequest = (IndexRequest) docWriteRequest; + + // Action line for index operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", indexRequest.index()); + if (indexRequest.id() != null) { + actionMetadata.put("_id", indexRequest.id()); + } + Map actionLine = new HashMap<>(); + actionLine.put("index", actionMetadata); + ndjson.append(StringUtils.toJson(actionLine)).append('\n'); + + // Document line + ndjson.append(indexRequest.source().utf8ToString()).append('\n'); + + } else if (docWriteRequest instanceof UpdateRequest) { + UpdateRequest updateRequest = (UpdateRequest) docWriteRequest; + + // Action line for update operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", updateRequest.index()); + actionMetadata.put("_id", updateRequest.id()); + Map actionLine = new HashMap<>(); + actionLine.put("update", actionMetadata); + ndjson.append(StringUtils.toJson(actionLine)).append('\n'); + + // Document line - for update, we need to wrap in "doc" or "script" + Map updateDoc = new HashMap<>(); + if (updateRequest.doc() != null) { + updateDoc.put("doc", XContentHelper.convertToMap(updateRequest.doc().source(), false, XContentType.JSON).v2()); + if (updateRequest.docAsUpsert()) { + updateDoc.put("doc_as_upsert", true); + } + } else if 
(updateRequest.script() != null) { + updateDoc.put("script", updateRequest.script()); + } + if (updateRequest.upsertRequest() != null) { + updateDoc + .put("upsert", XContentHelper.convertToMap(updateRequest.upsertRequest().source(), false, XContentType.JSON).v2()); + } + ndjson.append(StringUtils.toJson(updateDoc)).append('\n'); + + } else if (docWriteRequest instanceof DeleteRequest) { + DeleteRequest deleteRequest = (DeleteRequest) docWriteRequest; + + // Action line for delete operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", deleteRequest.index()); + actionMetadata.put("_id", deleteRequest.id()); + Map actionLine = new HashMap<>(); + actionLine.put("delete", actionMetadata); + ndjson.append(StringUtils.toJson(actionLine)).append('\n'); + + // Delete operations don't have a document line, just the action line + } + } + + return ndjson.toString(); + } + + @Test + public void test1() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest request = new IndexRequest(); + request.index("test"); + request.source(Map.of("key", "value")); + bulkRequest.add(request); + + DeleteRequest deleteRequest = new DeleteRequest(); + deleteRequest.index("test"); + deleteRequest.id("test11"); + bulkRequest.add(deleteRequest); + + UpdateRequest updateRequest = new UpdateRequest(); + updateRequest.index("test"); + updateRequest.id("test123"); + updateRequest.doc(Map.of("key", "value")); + bulkRequest.add(updateRequest); + + String s = convertBulkRequestToNDJSON1(bulkRequest); + System.out.println(s); + } + @Test public void toUTF8() { String rawString = "\uD83D\uDE00\uD83D\uDE0D\uD83D\uDE1C"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3b53935aaf..f500ae32d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -9,8 +9,10 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; import java.security.AccessController; import java.security.PrivilegedExceptionAction; @@ -110,7 +112,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java 
b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7804770258..7c6e89e076 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -8,9 +8,12 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; +import static software.amazon.awssdk.http.SdkHttpMethod.DELETE; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import static software.amazon.awssdk.http.SdkHttpMethod.PUT; +import java.net.URL; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; @@ -109,7 +112,13 @@ public void invokeRemoteService( request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, POST); break; case "GET": - request = ConnectorUtils.buildSdkRequest(action, connector, parameters, null, GET); + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, GET); + break; + case "PUT": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, PUT); + break; + case "DELETE": + request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, DELETE); break; default: throw new IllegalArgumentException("unsupported http method"); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index daad43eee1..2b7e428ab6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -159,7 +159,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) initIndexIfAbsent(indexName, mapping, version, listener); } - private String getMapping(String mappingPath) { + public String getMapping(String mappingPath) { if (mappingPath == null) { throw new IllegalArgumentException("Mapping path cannot be null"); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java index c1b5db1778..cb5fe126ab 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -5,6 +5,7 @@ package org.opensearch.ml.action.connector; +import static org.opensearch.ml.common.CommonValue.CONNECTOR_ACTION_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import org.opensearch.ResourceNotFoundException; @@ -18,6 +19,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.MLTaskResponse; import 
org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; @@ -73,14 +75,19 @@ public ExecuteConnectorTransportAction( protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); String connectorId = executeConnectorRequest.getConnectorId(); + RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) executeConnectorRequest.getMlInput().getInputDataset(); String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); + if (inputDataset.getParameters() != null && inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD) != null) { + connectorAction = inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD); + } if (MLIndicesHandler .doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) { + String finalConnectorAction = connectorAction; ActionListener listener = ActionListener.wrap(connector -> { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. - connector.decrypt(connectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + connector.decrypt(finalConnectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader .initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); @@ -89,7 +96,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + .executeAction(finalConnectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { actionListener.onResponse(taskResponse); }, e -> { actionListener.onFailure(e); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index a6118c9261..a06bbdf010 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -34,6 +34,7 @@ import org.opensearch.ml.helper.MemoryContainerModelValidator; import org.opensearch.ml.helper.MemoryContainerPipelineHelper; import org.opensearch.ml.helper.MemoryContainerSharedIndexValidator; +import org.opensearch.ml.helper.RemoteStorageHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.ml.utils.TenantAwareHelper; @@ -163,6 +164,9 @@ private void createMemoryDataIndices(MLMemoryContainer container, User user, Act MemoryConfiguration configuration = container.getConfiguration(); String indexPrefix = configuration != null ? 
configuration.getIndexPrefix() : null; + // Check if remote store is configured + boolean useRemoteStore = configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null; + // Convert to lowercase as OpenSearch doesn't support uppercase in index names final String sessionIndexName = configuration.getSessionIndexName(); final String workingMemoryIndexName = configuration.getWorkingMemoryIndexName(); @@ -173,32 +177,43 @@ private void createMemoryDataIndices(MLMemoryContainer container, User user, Act // No strategies = 2 indices (session/working only) if (configuration.getStrategies() == null || configuration.getStrategies().isEmpty()) { if (configuration.isDisableSession()) { - mlIndicesHandler.createWorkingMemoryDataIndex(workingMemoryIndexName, configuration, ActionListener.wrap(success -> { - // Return the actual index name that was created - // Create the memory data index with appropriate mapping - listener.onResponse(workingMemoryIndexName); - }, listener::onFailure)); - } else { - mlIndicesHandler.createSessionMemoryDataIndex(sessionIndexName, configuration, ActionListener.wrap(result -> { + if (useRemoteStore) { + createRemoteWorkingMemoryIndex(configuration, workingMemoryIndexName, ActionListener.wrap(success -> { + listener.onResponse(workingMemoryIndexName); + }, listener::onFailure)); + } else { mlIndicesHandler.createWorkingMemoryDataIndex(workingMemoryIndexName, configuration, ActionListener.wrap(success -> { - // Return the actual index name that was created - // Create the memory data index with appropriate mapping listener.onResponse(workingMemoryIndexName); }, listener::onFailure)); - }, listener::onFailure)); + } + } else { + if (useRemoteStore) { + createRemoteSessionMemoryIndex(configuration, sessionIndexName, ActionListener.wrap(result -> { + createRemoteWorkingMemoryIndex(configuration, workingMemoryIndexName, ActionListener.wrap(success -> { + listener.onResponse(workingMemoryIndexName); + }, listener::onFailure)); + }, listener::onFailure)); + } else { + mlIndicesHandler.createSessionMemoryDataIndex(sessionIndexName, configuration, ActionListener.wrap(result -> { + mlIndicesHandler + .createWorkingMemoryDataIndex(workingMemoryIndexName, configuration, ActionListener.wrap(success -> { + listener.onResponse(workingMemoryIndexName); + }, listener::onFailure)); + }, listener::onFailure)); + } } } else { if (configuration.isDisableSession()) { - createMemoryIndexes( - container, - listener, - configuration, - workingMemoryIndexName, - longTermMemoryIndexName, - longTermMemoryHistoryIndexName - ); - } else { - mlIndicesHandler.createSessionMemoryDataIndex(sessionIndexName, configuration, ActionListener.wrap(result -> { + if (useRemoteStore) { + createRemoteMemoryIndexes( + container, + listener, + configuration, + workingMemoryIndexName, + longTermMemoryIndexName, + longTermMemoryHistoryIndexName + ); + } else { createMemoryIndexes( container, listener, @@ -207,9 +222,32 @@ private void createMemoryDataIndices(MLMemoryContainer container, User user, Act longTermMemoryIndexName, longTermMemoryHistoryIndexName ); - }, listener::onFailure)); + } + } else { + if (useRemoteStore) { + createRemoteSessionMemoryIndex(configuration, sessionIndexName, ActionListener.wrap(result -> { + createRemoteMemoryIndexes( + container, + listener, + configuration, + workingMemoryIndexName, + longTermMemoryIndexName, + longTermMemoryHistoryIndexName + ); + }, listener::onFailure)); + } else { + mlIndicesHandler.createSessionMemoryDataIndex(sessionIndexName, 
configuration, ActionListener.wrap(result -> { + createMemoryIndexes( + container, + listener, + configuration, + workingMemoryIndexName, + longTermMemoryIndexName, + longTermMemoryHistoryIndexName + ); + }, listener::onFailure)); + } } - } } @@ -339,4 +377,58 @@ private void validateConfiguration(MemoryConfiguration config, ActionListener listener) { + String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStorageHelper.createRemoteSessionMemoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + } + + private void createRemoteWorkingMemoryIndex(MemoryConfiguration configuration, String indexName, ActionListener listener) { + String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStorageHelper.createRemoteWorkingMemoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + } + + private void createRemoteLongTermMemoryHistoryIndex( + MemoryConfiguration configuration, + String indexName, + ActionListener listener + ) { + String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStorageHelper + .createRemoteLongTermMemoryHistoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + } + + private void createRemoteMemoryIndexes( + MLMemoryContainer container, + ActionListener listener, + MemoryConfiguration configuration, + String workingMemoryIndexName, + String longTermMemoryIndexName, + String longTermMemoryHistoryIndexName + ) { + createRemoteWorkingMemoryIndex(configuration, workingMemoryIndexName, ActionListener.wrap(success -> { + // Create long-term memory index with pipeline if embedding is configured + createRemoteLongTermMemoryIngestPipeline(configuration, longTermMemoryIndexName, ActionListener.wrap(success1 -> { + if (!configuration.isDisableHistory()) { + createRemoteLongTermMemoryHistoryIndex(configuration, longTermMemoryHistoryIndexName, ActionListener.wrap(success2 -> { + listener.onResponse(longTermMemoryIndexName); + }, listener::onFailure)); + } else { + listener.onResponse(longTermMemoryIndexName); + } + }, listener::onFailure)); + }, listener::onFailure)); + } + + private void createRemoteLongTermMemoryIngestPipeline( + MemoryConfiguration configuration, + String indexName, + ActionListener listener + ) { + String connectorId = configuration.getRemoteStore().getConnectorId(); + MemoryContainerPipelineHelper + .createRemoteLongTermMemoryIngestPipeline(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java index 0e7b690a26..8556629dee 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java @@ -74,26 +74,9 @@ private void searchFactsSequentially( String fact = facts.get(currentIndex); try { - QueryBuilder queryBuilder = MemorySearchQueryBuilder - .buildFactSearchQuery(strategy, fact, input.getNamespace(), input.getOwnerId(), memoryConfig, input.getMemoryContainerId()); - - log.debug("Searching for similar facts"); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.query(queryBuilder); - searchSourceBuilder.size(maxInferSize); - searchSourceBuilder.fetchSource(new String[] { MEMORY_FIELD }, null); - 
String indexName = memoryConfig.getLongMemoryIndexName(); String tenantId = memoryConfig.getTenantId(); - SearchDataObjectRequest searchRequest = SearchDataObjectRequest - .builder() - .indices(indexName) - .searchSourceBuilder(searchSourceBuilder) - .tenantId(tenantId) - .build(); - ActionListener searchResponseActionListener = ActionListener.wrap(response -> { for (SearchHit hit : response.getHits().getHits()) { Map sourceMap = hit.getSourceAsMap(); @@ -103,14 +86,52 @@ private void searchFactsSequentially( } } - log.debug("Found {} similar facts", response.getHits().getHits().length); + log.debug("Found {} similar facts for: {}", response.getHits().getHits().length, fact); searchFactsSequentially(strategy, input, facts, currentIndex + 1, memoryConfig, maxInferSize, allResults, listener); }, e -> { - log.error("Failed to search for similar facts"); + log.error("Failed to search for similar facts for: {}", fact, e); searchFactsSequentially(strategy, input, facts, currentIndex + 1, memoryConfig, maxInferSize, allResults, listener); }); - memoryContainerHelper.searchData(memoryConfig, searchRequest, searchResponseActionListener); + if (memoryConfig.getRemoteStore() == null) { + QueryBuilder queryBuilder = MemorySearchQueryBuilder + .buildFactSearchQuery( + strategy, + fact, + input.getNamespace(), + input.getOwnerId(), + memoryConfig, + input.getMemoryContainerId() + ); + + log.debug("Searching for similar facts"); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.size(maxInferSize); + searchSourceBuilder.fetchSource(new String[] { MEMORY_FIELD }, null); + + SearchDataObjectRequest searchRequest = SearchDataObjectRequest + .builder() + .indices(indexName) + .searchSourceBuilder(searchSourceBuilder) + .tenantId(tenantId) + .build(); + + memoryContainerHelper.searchData(memoryConfig, searchRequest, searchResponseActionListener); + } else { + String query = MemorySearchQueryBuilder + .buildFactSearchQueryForAoss( + strategy, + fact, + input.getNamespace(), + input.getOwnerId(), + memoryConfig, + input.getMemoryContainerId(), + maxInferSize + ); + memoryContainerHelper.searchDataFromRemoteStorage(memoryConfig, indexName, query, searchResponseActionListener); + } } catch (Exception e) { log.error("Failed to build search query for facts"); searchFactsSequentially(strategy, input, facts, currentIndex + 1, memoryConfig, maxInferSize, allResults, listener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java index 823c7c0548..ab48d287f9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java @@ -141,7 +141,13 @@ private void searchMemories( log.error("Search execution failed", e); actionListener.onFailure(new OpenSearchException("Search execution failed: " + e.getMessage(), e)); }); - memoryContainerHelper.searchData(container.getConfiguration(), searchDataObjecRequest, searchResponseActionListener); + + if (memoryConfig.getRemoteStore() == null) { + memoryContainerHelper.searchData(container.getConfiguration(), searchDataObjecRequest, searchResponseActionListener); + } else { + String query = input.getSearchSourceBuilder().toString(); + 
memoryContainerHelper.searchDataFromRemoteStorage(memoryConfig, indexName, query, searchResponseActionListener); + } } catch (Exception e) { log.error("Failed to build search request", e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java index bd0df4222e..4fe38be24f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java @@ -129,8 +129,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener newDoc = constructNewDoc(updateRequest.getMlUpdateMemoryInput(), memoryType, originalDoc); IndexRequest indexRequest = new IndexRequest(memoryIndexName).id(memoryId).source(newDoc); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - memoryContainerHelper.indexData(container.getConfiguration(), indexRequest, actionListener); - + if (container.getConfiguration().getRemoteStore() == null) { + memoryContainerHelper.indexData(container.getConfiguration(), indexRequest, actionListener); + } else { + memoryContainerHelper.updateDataToRemoteStorage(container.getConfiguration(), indexRequest, actionListener); + } }, actionListener::onFailure); memoryContainerHelper.getData(container.getConfiguration(), getRequest, getResponseActionListener); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index 4e819f103b..d16ba9f6f0 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -16,12 +16,19 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.OWNER_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.lucene.search.join.ScoreMode; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteRequest; @@ -37,9 +44,12 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -57,6 +67,7 @@ import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; import org.opensearch.ml.common.memorycontainer.MemoryType; +import 
org.opensearch.ml.common.utils.StringUtils; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; @@ -223,7 +234,9 @@ public String getMemoryIndexName(MLMemoryContainer container, MemoryType memoryT } public void getData(MemoryConfiguration configuration, GetRequest getRequest, ActionListener listener) { - if (configuration.isUseSystemIndex()) { + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + getDataFromRemoteStorage(configuration, getRequest, listener); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.get(getRequest, ActionListener.runBefore(listener, context::restore)); } @@ -232,13 +245,39 @@ public void getData(MemoryConfiguration configuration, GetRequest getRequest, Ac } } + private void getDataFromRemoteStorage(MemoryConfiguration configuration, GetRequest getRequest, ActionListener listener) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + String indexName = getRequest.indices()[0]; + String docId = getRequest.id(); + + // Convert SearchSourceBuilder to Map + RemoteStorageHelper + .getDocument( + connectorId, + indexName, + docId, + client, + ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure) + ); + } catch (Exception e) { + log.error("Failed to search data from remote storage", e); + listener.onFailure(e); + } + } + public void searchData( MemoryConfiguration configuration, SearchDataObjectRequest searchRequest, ActionListener listener ) { try { - if (configuration.isUseSystemIndex()) { + // Check if remote store is configured + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + // Use remote storage + // searchDataFromRemoteStorage(configuration, searchRequest, listener); + throw new RuntimeException("Remote store is not yet implemented"); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); final ActionListener doubleWrappedListener = ActionListener @@ -258,8 +297,41 @@ public void searchData( } } + public void searchDataFromRemoteStorage( + MemoryConfiguration configuration, + String indexName, + String query, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStorageHelper.searchDocuments(connectorId, indexName, query, client, ActionListener.wrap(response -> { + listener.onResponse(response); + }, listener::onFailure)); + } catch (Exception e) { + log.error("Failed to search data from remote storage", e); + listener.onFailure(e); + } + } + + private Map convertSearchSourceToMap(SearchSourceBuilder searchSourceBuilder) throws IOException { + if (searchSourceBuilder == null) { + return new HashMap<>(); + } + + // Convert SearchSourceBuilder to JSON string then to Map + String jsonString = searchSourceBuilder.toString(); + XContentParser parser = XContentHelper + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, new BytesArray(jsonString), XContentType.JSON); + return parser.mapOrdered(); + } + public void indexData(MemoryConfiguration configuration, IndexRequest indexRequest, 
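// The remote-store check above is repeated across getData/searchData/indexData/updateData/
// deleteData/bulkIngestData; a hypothetical predicate capturing the routing rule (not part of
// this change, shown only to make the convention explicit):
private static boolean usesRemoteStore(MemoryConfiguration configuration) {
    return configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null;
}
// When it returns true the helper delegates to RemoteStorageHelper through a connector;
// otherwise the existing system-index / regular client paths are used unchanged.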
ActionListener listener) { - if (configuration.isUseSystemIndex()) { + // Check if remote store is configured + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + // Use remote storage + indexDataToRemoteStorage(configuration, indexRequest, listener); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.index(indexRequest, ActionListener.runBefore(listener, context::restore)); } @@ -268,8 +340,62 @@ public void indexData(MemoryConfiguration configuration, IndexRequest indexReque } } + public void updateDataToRemoteStorage( + MemoryConfiguration configuration, + IndexRequest indexRequest, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + String indexName = indexRequest.index(); + String docId = indexRequest.id(); + + // Convert IndexRequest source to Map + Map documentSource = indexRequest.sourceAsMap(); + + RemoteStorageHelper + .updateDocument(connectorId, indexName, docId, documentSource, client, ActionListener.wrap(updateResponse -> { + IndexResponse response = new IndexResponse( + updateResponse.getShardId(), + updateResponse.getId(), + updateResponse.getSeqNo(), + updateResponse.getPrimaryTerm(), + updateResponse.getVersion(), + false + ); + listener.onResponse(response); + }, listener::onFailure)); + } catch (Exception e) { + log.error("Failed to index data to remote storage", e); + listener.onFailure(e); + } + } + + private void indexDataToRemoteStorage( + MemoryConfiguration configuration, + IndexRequest indexRequest, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + String indexName = indexRequest.index(); + + // Convert IndexRequest source to Map + Map documentSource = indexRequest.sourceAsMap(); + + RemoteStorageHelper.writeDocument(connectorId, indexName, documentSource, client, ActionListener.wrap(response -> { + listener.onResponse(response); + }, listener::onFailure)); + } catch (Exception e) { + log.error("Failed to index data to remote storage", e); + listener.onFailure(e); + } + } + public void updateData(MemoryConfiguration configuration, UpdateRequest updateRequest, ActionListener listener) { - if (configuration.isUseSystemIndex()) { + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + updateDataInRemoteStorage(configuration, updateRequest, listener); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.update(updateRequest, ActionListener.runBefore(listener, context::restore)); } @@ -278,8 +404,46 @@ public void updateData(MemoryConfiguration configuration, UpdateRequest updateRe } } + private void updateDataInRemoteStorage( + MemoryConfiguration configuration, + UpdateRequest updateRequest, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + String indexName = updateRequest.index(); + String docId = updateRequest.id(); + + Map documentSource = convertUpdateRequestToMap(updateRequest); + RemoteStorageHelper.updateDocument(connectorId, indexName, docId, documentSource, client, ActionListener.wrap(response -> { + listener.onResponse(response); + }, listener::onFailure)); + } catch (Exception e) { + log.error("Failed to update data in remote storage", e); + 
listener.onFailure(e); + } + } + + private Map convertUpdateRequestToMap(UpdateRequest updateRequest) throws IOException { + Map result = new HashMap<>(); + + if (updateRequest.doc() != null) { + result.put("doc", updateRequest.doc().sourceAsMap()); + } + + // Handle upsert if present + if (updateRequest.upsertRequest() != null) { + result.put("doc_as_upsert", true); + result.put("doc", updateRequest.upsertRequest().sourceAsMap()); + } + + return result; + } + public void deleteData(MemoryConfiguration configuration, DeleteRequest deleteRequest, ActionListener listener) { - if (configuration.isUseSystemIndex()) { + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + deleteDataFromRemoteStorage(configuration, deleteRequest, listener); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.delete(deleteRequest, ActionListener.runBefore(listener, context::restore)); } @@ -288,6 +452,30 @@ public void deleteData(MemoryConfiguration configuration, DeleteRequest deleteRe } } + private void deleteDataFromRemoteStorage( + MemoryConfiguration configuration, + DeleteRequest deleteRequest, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + String indexName = deleteRequest.index(); + String docId = deleteRequest.id(); + + RemoteStorageHelper + .deleteDocument( + connectorId, + indexName, + docId, + client, + ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure) + ); + } catch (Exception e) { + log.error("Failed to delete data from remote storage", e); + listener.onFailure(e); + } + } + public void deleteIndex( MemoryConfiguration configuration, DeleteIndexRequest deleteIndexRequest, @@ -303,7 +491,9 @@ public void deleteIndex( } public void bulkIngestData(MemoryConfiguration configuration, BulkRequest bulkRequest, ActionListener listener) { - if (configuration.isUseSystemIndex()) { + if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + bulkIngestDataToRemoteStorage(configuration, bulkRequest, listener); + } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.bulk(bulkRequest, ActionListener.runBefore(listener, context::restore)); } @@ -312,6 +502,144 @@ public void bulkIngestData(MemoryConfiguration configuration, BulkRequest bulkRe } } + private void bulkIngestDataToRemoteStorage( + MemoryConfiguration configuration, + BulkRequest bulkRequest, + ActionListener listener + ) { + try { + String connectorId = configuration.getRemoteStore().getConnectorId(); + List bulkBodyList = convertBulkRequestToNDJSON(bulkRequest); + + if (bulkBodyList.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Empty bulk request")); + return; + } + + // Process sequentially + bulkIngestSequentially(connectorId, bulkBodyList, 0, new ArrayList<>(), listener); + + } catch (Exception e) { + log.error("Failed to bulk ingest data to remote storage", e); + listener.onFailure(e); + } + } + + private void bulkIngestSequentially( + String connectorId, + List bulkBodyList, + int index, + List responses, + ActionListener finalListener + ) { + if (index >= bulkBodyList.size()) { + // All done, merge responses + BulkResponse mergedResponse = mergeBulkResponses(responses); + finalListener.onResponse(mergedResponse); + return; 
+ } + + RemoteStorageHelper.bulkWrite(connectorId, bulkBodyList.get(index), client, ActionListener.wrap(response -> { + responses.add(response); + // Process next + bulkIngestSequentially(connectorId, bulkBodyList, index + 1, responses, finalListener); + }, finalListener::onFailure)); + } + + private BulkResponse mergeBulkResponses(Collection responses) { + List allItems = new ArrayList<>(); + long totalTook = 0; + long totalIngestTook = 0; + boolean hasErrors = false; + + for (BulkResponse response : responses) { + allItems.addAll(Arrays.asList(response.getItems())); + totalTook += response.getTook().millis(); + totalIngestTook += response.getIngestTookInMillis(); + hasErrors |= response.hasFailures(); + } + + return new BulkResponse(allItems.toArray(new BulkItemResponse[0]), totalTook, totalIngestTook); + } + + private List convertBulkRequestToNDJSON(BulkRequest bulkRequest) { + + StringBuilder ndjsonForIndex = new StringBuilder(); + StringBuilder ndjsonForUpdateDelete = new StringBuilder(); + + boolean indexExists = false; + boolean updateDeleteExists = false; + + for (var docWriteRequest : bulkRequest.requests()) { + if (docWriteRequest instanceof IndexRequest) { + IndexRequest indexRequest = (IndexRequest) docWriteRequest; + + // Action line for index operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", indexRequest.index()); + if (indexRequest.id() != null) { // TODO: throw exception AOSS doesn't support doc id + actionMetadata.put("_id", indexRequest.id()); + } + Map actionLine = new HashMap<>(); + actionLine.put("index", actionMetadata); + ndjsonForIndex.append(StringUtils.toJson(actionLine)).append('\n'); + + // Document line + ndjsonForIndex.append(indexRequest.source().utf8ToString()).append('\n'); + indexExists = true; + } else if (docWriteRequest instanceof UpdateRequest) { + UpdateRequest updateRequest = (UpdateRequest) docWriteRequest; + + // Action line for update operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", updateRequest.index()); + actionMetadata.put("_id", updateRequest.id()); + Map actionLine = new HashMap<>(); + actionLine.put("update", actionMetadata); + ndjsonForUpdateDelete.append(StringUtils.toJson(actionLine)).append('\n'); + + // Document line - for update, we need to wrap in "doc" or "script" + Map updateDoc = new HashMap<>(); + if (updateRequest.doc() != null) { + updateDoc.put("doc", XContentHelper.convertToMap(updateRequest.doc().source(), false, XContentType.JSON).v2()); + if (updateRequest.docAsUpsert()) { + updateDoc.put("doc_as_upsert", true); + } + } else if (updateRequest.script() != null) { + updateDoc.put("script", updateRequest.script()); + } + if (updateRequest.upsertRequest() != null) { + updateDoc + .put("upsert", XContentHelper.convertToMap(updateRequest.upsertRequest().source(), false, XContentType.JSON).v2()); + } + ndjsonForUpdateDelete.append(StringUtils.toJson(updateDoc)).append('\n'); + updateDeleteExists = true; + } else if (docWriteRequest instanceof DeleteRequest) { + DeleteRequest deleteRequest = (DeleteRequest) docWriteRequest; + + // Action line for delete operation + Map actionMetadata = new HashMap<>(); + actionMetadata.put("_index", deleteRequest.index()); + actionMetadata.put("_id", deleteRequest.id()); + Map actionLine = new HashMap<>(); + actionLine.put("delete", actionMetadata); + ndjsonForUpdateDelete.append(StringUtils.toJson(actionLine)).append('\n'); + updateDeleteExists = true; + // Delete operations don't have a document line, just the action line + } + } + + 
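// Illustration of the conversion above (index names, ids, and field values are examples only):
// an index request and a delete request in the same BulkRequest are split into two NDJSON
// payloads, because index operations and update/delete operations are accumulated separately.
BulkRequest exampleBulk = new BulkRequest();
exampleBulk.add(new IndexRequest("memory-working-index").source(Map.of("memory", "user prefers tea")));
exampleBulk.add(new DeleteRequest("memory-working-index", "doc-2"));
// convertBulkRequestToNDJSON(exampleBulk) would yield:
//   payload 1: {"index":{"_index":"memory-working-index"}}  followed by the document line
//   payload 2: {"delete":{"_index":"memory-working-index","_id":"doc-2"}}
// and bulkIngestSequentially(...) sends these payloads one at a time, merging the BulkResponses.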
List result = new ArrayList<>(); + if (indexExists) { + result.add(ndjsonForIndex.toString()); + } + if (updateDeleteExists) { + result.add(ndjsonForUpdateDelete.toString()); + } + + return result; + } + /** * Execute delete by query with proper system index handling * diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java index 73765164d4..fd6682c477 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java @@ -19,6 +19,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.transport.client.Client; @@ -227,4 +228,112 @@ public static void createHistoryIndexIfEnabled( listener.onResponse(true); } } + + /** + * Creates an ingest pipeline in remote storage and long-term memory index. + *
+ * If embedding is configured, creates a text embedding pipeline in the remote cluster first, + * then creates the long-term index with the pipeline attached. + * If no embedding is configured, creates the index without a pipeline. + * + * @param connectorId The connector ID for remote storage + * @param indexName The long-term memory index name + * @param config The memory configuration + * @param indicesHandler The ML indices handler + * @param client The OpenSearch client + * @param listener Action listener that receives true on success, or error on failure + */ + public static void createRemoteLongTermMemoryIngestPipeline( + String connectorId, + String indexName, + MemoryConfiguration config, + MLIndicesHandler indicesHandler, + Client client, + ActionListener listener + ) { + try { + if (config.getRemoteStore().getEmbeddingModelType() != null) { + String pipelineName = indexName + "-embedding"; + + createRemoteTextEmbeddingPipeline(connectorId, pipelineName, config, client, ActionListener.wrap(success -> { + log.info("Successfully created remote text embedding pipeline: {}", pipelineName); + // Now create the remote long-term memory index with the pipeline + org.opensearch.ml.helper.RemoteStorageHelper + .createRemoteLongTermMemoryIndexWithPipeline( + connectorId, + indexName, + pipelineName, + config, + indicesHandler, + client, + listener + ); + }, e -> { + log.error("Failed to create remote text embedding pipeline '{}'", pipelineName, e); + listener.onFailure(e); + })); + } else { + // No embedding configured, create index without pipeline + org.opensearch.ml.helper.RemoteStorageHelper + .createRemoteLongTermMemoryIndex(connectorId, indexName, config, indicesHandler, client, listener); + } + } catch (Exception e) { + log.error("Failed to create remote long-term memory infrastructure for index: {}", indexName, e); + listener.onFailure(e); + } + } + + /** + * Creates a text embedding pipeline in remote storage for memory container. + *
+ * Creates a pipeline with the appropriate embedding processor in the remote cluster. + * Uses the remote embedding model ID if specified in remote_store configuration, + * otherwise falls back to the local embedding model ID. + * + * @param connectorId The connector ID for remote storage + * @param pipelineName The pipeline name + * @param config The memory configuration + * @param client The OpenSearch client + * @param listener Action listener that receives true on success, or error on failure + */ + public static void createRemoteTextEmbeddingPipeline( + String connectorId, + String pipelineName, + MemoryConfiguration config, + Client client, + ActionListener listener + ) { + try { + RemoteStore remoteStore = config.getRemoteStore(); + String processorName = remoteStore.getEmbeddingModelType() == org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING + ? "text_embedding" + : "sparse_encoding"; + + String embeddingModelId = remoteStore.getEmbeddingModelId(); + + XContentBuilder builder = XContentFactory + .jsonBuilder() + .startObject() + .field("description", "Agentic Memory Text embedding pipeline") + .startArray("processors") + .startObject() + .startObject(processorName) + .field("model_id", embeddingModelId) + .startObject("field_map") + .field(MEMORY_FIELD, MEMORY_EMBEDDING_FIELD) + .endObject() + .endObject() + .endObject() + .endArray() + .endObject(); + + String pipelineBody = builder.toString(); + + // Use RemoteStorageHelper to create the pipeline in remote storage + org.opensearch.ml.helper.RemoteStorageHelper.createRemotePipeline(connectorId, pipelineName, pipelineBody, client, listener); + } catch (IOException e) { + log.error("Failed to build remote pipeline configuration for '{}'", pipelineName, e); + listener.onFailure(e); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java new file mode 100644 index 0000000000..b5ff675097 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java @@ -0,0 +1,705 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.helper; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.ml.common.CommonValue.CONNECTOR_ACTION_FIELD; +import static org.opensearch.ml.common.CommonValue.ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.CommonValue.ML_LONG_TERM_MEMORY_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.CommonValue.ML_MEMORY_SESSION_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.CommonValue.ML_WORKING_MEMORY_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LONG_TERM_MEMORY_HISTORY_INDEX; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LONG_TERM_MEMORY_INDEX; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_EMBEDDING_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.NAMESPACE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.NAMESPACE_SIZE_FIELD; +import static 
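// For reference, with a TEXT_EMBEDDING remote model the builder in createRemoteTextEmbeddingPipeline
// above produces a body equivalent to the following (the model id is illustrative, and the
// field_map entries are shown by their constant names rather than their literal values):
//   {
//     "description": "Agentic Memory Text embedding pipeline",
//     "processors": [
//       { "text_embedding": { "model_id": "<remote embedding model id>",
//                             "field_map": { <MEMORY_FIELD>: <MEMORY_EMBEDDING_FIELD> } } }
//     ]
//   }
// With a SPARSE_ENCODING model the processor name becomes "sparse_encoding" instead.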
org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.OWNER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_INDEX; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STRATEGY_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.WORKING_MEMORY_INDEX; +import static org.opensearch.ml.common.utils.ToolUtils.NO_ESCAPE_PARAMS; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; +import org.opensearch.ml.common.memorycontainer.MemoryStrategy; +import org.opensearch.ml.common.memorycontainer.RemoteStore; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; +import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Helper class for creating memory indices in remote storage using connectors + */ +@Log4j2 +public class RemoteStorageHelper { + + private static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; + private static final String CREATE_INDEX_ACTION = "create_index"; + private static final String WRITE_DOC_ACTION = "write_doc"; + private static final String BULK_LOAD_ACTION = "bulk_load"; + private static final String SEARCH_INDEX_ACTION = "search_index"; + private static final String UPDATE_DOC_ACTION = "update_doc"; + private static final String GET_DOC_ACTION = "get_doc"; + private static final String DELETE_DOC_ACTION = "delete_doc"; + private static final String INDEX_NAME_PARAM = "index_name"; + private static final String DOC_ID_PARAM = "doc_id"; + private static final String INPUT_PARAM = "input"; + + /** + * Creates a memory index in remote storage using a connector + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index to create + * @param indexMapping The index mapping as a JSON string + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void createRemoteIndex( + String connectorId, + String 
indexName, + String indexMapping, + Client client, + ActionListener listener + ) { + createRemoteIndex(connectorId, indexName, indexMapping, null, client, listener); + } + + /** + * Creates a memory index in remote storage using a connector with custom settings + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index to create + * @param indexMapping The index mapping as a JSON string + * @param indexSettings The index settings as a Map (can be null) + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void createRemoteIndex( + String connectorId, + String indexName, + String indexMapping, + Map indexSettings, + Client client, + ActionListener listener + ) { + try { + // Parse the mapping string to a Map + Map mappingMap = parseMappingToMap(indexMapping); + + // Build the request body for creating the index + Map requestBody = new HashMap<>(); + requestBody.put("mappings", mappingMap); + + // Add settings if provided (settings should already have "index." prefix) + if (indexSettings != null && !indexSettings.isEmpty()) { + requestBody.put("settings", indexSettings); + } + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJson(requestBody)); + parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); + + // Execute the connector action + executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully created remote index: {}", indexName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote index creation for: {}", indexName, e); + listener.onFailure(e); + } + } + + /** + * Creates session memory index in remote storage + */ + public static void createRemoteSessionMemoryIndex( + String connectorId, + String indexName, + MemoryConfiguration configuration, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_MEMORY_SESSION_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(SESSION_INDEX); + createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, client, listener); + } + + /** + * Creates working memory index in remote storage + */ + public static void createRemoteWorkingMemoryIndex( + String connectorId, + String indexName, + MemoryConfiguration configuration, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_WORKING_MEMORY_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(WORKING_MEMORY_INDEX); + createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, client, listener); + } + + /** + * Creates long-term memory history index in remote storage + */ + public static void createRemoteLongTermMemoryHistoryIndex( + String connectorId, + String indexName, + MemoryConfiguration configuration, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(LONG_TERM_MEMORY_HISTORY_INDEX); + createRemoteIndex(connectorId, indexName, 
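// Summary of the remote create_index call assembled above: the connector action receives
//   index_name        -> the target index name
//   input             -> a JSON body of the form {"mappings": {...}, "settings": {...}}
//   connector_action  -> "create_index"
// The session/working/long-term-history variants differ only in which bundled mapping file
// (via MLIndicesHandler.getMapping) and which per-index settings from MemoryConfiguration they pass.
// A minimal usage sketch (connector id and index name are illustrative):
RemoteStorageHelper.createRemoteSessionMemoryIndex(
    "my-aoss-connector-id",
    "my-session-memory-index",
    configuration,
    mlIndicesHandler,
    client,
    ActionListener.wrap(created -> log.info("remote session index created"), listener::onFailure));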
indexMappings, indexSettings, client, listener); + } + + /** + * Creates long-term memory index in remote storage with dynamic embedding configuration + */ + public static void createRemoteLongTermMemoryIndex( + String connectorId, + String indexName, + MemoryConfiguration memoryConfig, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + try { + String indexMapping = buildLongTermMemoryMapping(memoryConfig, mlIndicesHandler); + Map indexSettings = buildLongTermMemorySettings(memoryConfig); + createRemoteIndex(connectorId, indexName, indexMapping, indexSettings, client, listener); + } catch (Exception e) { + log.error("Failed to build long-term memory mapping for remote index: {}", indexName, e); + listener.onFailure(e); + } + } + + /** + * Builds the long-term memory index mapping dynamically based on configuration + */ + private static String buildLongTermMemoryMapping(MemoryConfiguration memoryConfig, MLIndicesHandler mlIndicesHandler) + throws IOException { + String baseMappingJson = mlIndicesHandler.getMapping(ML_LONG_TERM_MEMORY_INDEX_MAPPING_PATH); + + Map mapping = new HashMap<>(); + Map properties = new HashMap<>(); + + XContentParser parser = XContentHelper + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + new BytesArray(baseMappingJson), + XContentType.JSON + ); + + Map baseMapping = parser.mapOrdered(); + mapping.put("_meta", baseMapping.get("_meta")); + properties.putAll((Map) baseMapping.get("properties")); + + RemoteStore remoteStore = memoryConfig.getRemoteStore(); + // Add embedding field based on configuration + if (remoteStore.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) { + Map knnVector = new HashMap<>(); + knnVector.put("type", "knn_vector"); + knnVector.put("dimension", remoteStore.getEmbeddingDimension()); + properties.put(MEMORY_EMBEDDING_FIELD, knnVector); + } else if (remoteStore.getEmbeddingModelType() == FunctionName.SPARSE_ENCODING) { + properties.put(MEMORY_EMBEDDING_FIELD, Map.of("type", "rank_features")); + } + + mapping.put("properties", properties); + return StringUtils.toJson(mapping); + } + + /** + * Builds the long-term memory index settings dynamically based on configuration + * Returns settings with "index." 
prefix as required by OpenSearch/AOSS + */ + private static Map buildLongTermMemorySettings(MemoryConfiguration memoryConfig) { + Map indexSettings = new HashMap<>(); + + RemoteStore remoteStore = memoryConfig.getRemoteStore(); + if (remoteStore.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) { + indexSettings.put("index.knn", true); + } + + // Add custom settings from configuration + if (!memoryConfig.getIndexSettings().isEmpty() && memoryConfig.getIndexSettings().containsKey(LONG_TERM_MEMORY_INDEX)) { + Map configuredIndexSettings = memoryConfig.getMemoryIndexMapping(LONG_TERM_MEMORY_INDEX); + indexSettings.putAll(configuredIndexSettings); + } + + return indexSettings; + } + + /** + * Executes a connector action with a specific action name + */ + private static void executeConnectorAction( + String connectorId, + String actionName, + Map parameters, + Client client, + ActionListener listener + ) { + // Add connector_action parameter to specify which action to execute + Map allParameters = new HashMap<>(parameters); + allParameters.put(CONNECTOR_ACTION_FIELD, actionName); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(allParameters).build(); + MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + MLExecuteConnectorRequest request = new MLExecuteConnectorRequest(connectorId, mlInput); + + client.execute(MLExecuteConnectorAction.INSTANCE, request, ActionListener.wrap(r -> { + ModelTensorOutput output = (ModelTensorOutput) r.getOutput(); + listener.onResponse(output); + }, e -> { + log.error("Failed to execute connector action {} for connector: {}", actionName, connectorId, e); + listener.onFailure(e); + })); + } + + /** + * Executes a connector action (backward compatibility - defaults to create_index) + */ + private static void executeConnectorAction( + String connectorId, + Map parameters, + Client client, + ActionListener listener + ) { + executeConnectorAction(connectorId, CREATE_INDEX_ACTION, parameters, client, listener); + } + + /** + * Writes a single document to remote storage + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index + * @param documentSource The document source as a Map + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void writeDocument( + String connectorId, + String indexName, + Map documentSource, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); + + // Execute the connector action with write_doc action name + executeConnectorAction(connectorId, WRITE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + // Extract document ID from response + XContentParser parser = createParserFromTensorOutput(response); + IndexResponse indexResponse = IndexResponse.fromXContent(parser); + listener.onResponse(indexResponse); + }, e -> { + log.error("Failed to write document to remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document write for index: {}", indexName, e); + listener.onFailure(e); + } + } + + /** + * Performs bulk write operations to remote storage + * + * @param connectorId The connector ID to use for remote storage + * 
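// Response handling pattern shared by writeDocument/bulkWrite/searchDocuments/updateDocument/
// getDocument/deleteDocument: the connector returns a ModelTensorOutput whose first tensor's
// dataAsMap is the raw remote OpenSearch response, which is re-serialized to JSON and parsed
// back with the matching fromXContent parser. For a single write, for example
// (modelTensorOutput stands for the connector output and is illustrative):
XContentParser responseParser = RemoteStorageHelper.createParserFromTensorOutput(modelTensorOutput);
IndexResponse remoteIndexResponse = IndexResponse.fromXContent(responseParser);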
@param bulkBody The bulk request body in NDJSON format + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void bulkWrite(String connectorId, String bulkBody, Client client, ActionListener listener) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INPUT_PARAM, bulkBody); + parameters.put(NO_ESCAPE_PARAMS, INPUT_PARAM); + + // Execute the connector action with bulk_load action name + executeConnectorAction(connectorId, BULK_LOAD_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully executed bulk write to remote storage"); + XContentParser parser = createParserFromTensorOutput(response); + BulkResponse bulkResponse = BulkResponse.fromXContent(parser); + listener.onResponse(bulkResponse); + }, e -> { + log.error("Failed to execute bulk write to remote storage", e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote bulk write", e); + listener.onFailure(e); + } + } + + public static void searchDocuments( + String connectorId, + String indexName, + String query, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, query); + + // Execute the connector action with search_index action name + executeConnectorAction(connectorId, SEARCH_INDEX_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully searched documents in remote index: {}", indexName); + XContentParser parser = createParserFromTensorOutput(response); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + listener.onResponse(searchResponse); + }, e -> { + log.error("Failed to search documents in remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote search for index: {}", indexName, e); + listener.onFailure(e); + } + } + + /** + * Updates a document in remote storage + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index + * @param docId The document ID to update + * @param documentSource The document source as a Map + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void updateDocument( + String connectorId, + String indexName, + String docId, + Map documentSource, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); + + // Execute the connector action with update_doc action name + executeConnectorAction(connectorId, UPDATE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully updated document in remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + UpdateResponse updateResponse = UpdateResponse.fromXContent(parser); + listener.onResponse(updateResponse); + }, e -> { + log.error("Failed to update document in remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document update for index: {}, doc_id: {}", indexName, 
docId, e); + listener.onFailure(e); + } + } + + public static void getDocument( + String connectorId, + String indexName, + String docId, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + // input parameter is optional for delete, use empty string as default + parameters.put(INPUT_PARAM, ""); + + // Execute the connector action with delete_doc action name + executeConnectorAction(connectorId, GET_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully deleted document from remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + GetResponse getResponse = GetResponse.fromXContent(parser); + listener.onResponse(getResponse); + }, e -> { + log.error("Failed to delete document from remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document delete for index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + } + } + + /** + * Deletes a document from remote storage + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index + * @param docId The document ID to delete + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void deleteDocument( + String connectorId, + String indexName, + String docId, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + // input parameter is optional for delete, use empty string as default + parameters.put(INPUT_PARAM, ""); + + // Execute the connector action with delete_doc action name + executeConnectorAction(connectorId, DELETE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully deleted document from remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + DeleteResponse deleteResponse = DeleteResponse.fromXContent(parser); + listener.onResponse(deleteResponse); + }, e -> { + log.error("Failed to delete document from remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document delete for index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + } + } + + /** + * Parses a JSON mapping string to a Map + */ + private static Map parseMappingToMap(String mappingJson) throws IOException { + XContentParser parser = XContentHelper + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, new BytesArray(mappingJson), XContentType.JSON); + return parser.mapOrdered(); + } + + public static XContentParser createParserFromTensorOutput(ModelTensorOutput output) throws IOException { + Map dataAsMap = output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + String json = StringUtils.toJson(dataAsMap); + XContentParser parser = jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); + return parser; + } + + public static QueryBuilder buildFactSearchQuery( + MemoryStrategy strategy, + String fact, + Map namespace, + String ownerId, + 
MemoryConfiguration memoryConfig, + String memoryContainerId + ) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + + // Add filter conditions + for (String key : strategy.getNamespace()) { + if (!namespace.containsKey(key)) { + throw new IllegalArgumentException("Namespace does not contain key: " + key); + } + boolQuery.filter(QueryBuilders.termQuery(NAMESPACE_FIELD + "." + key, namespace.get(key))); + } + if (ownerId != null) { + boolQuery.filter(QueryBuilders.termQuery(OWNER_ID_FIELD, ownerId)); + } + boolQuery.filter(QueryBuilders.termQuery(NAMESPACE_SIZE_FIELD, strategy.getNamespace().size())); + // Filter by strategy_id to prevent cross-strategy interference (sufficient for uniqueness) + boolQuery.filter(QueryBuilders.termQuery(STRATEGY_ID_FIELD, strategy.getId())); + // Filter by memory_container_id to prevent cross-container access when containers share the same index prefix + if (memoryContainerId != null && !memoryContainerId.isBlank()) { + boolQuery.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD, memoryContainerId)); + } + + // Add the search query + if (memoryConfig != null) { + if (memoryConfig.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) { + StringBuilder neuralSearchQuery = new StringBuilder() + .append("{\"neural\":{\"") + .append(MEMORY_EMBEDDING_FIELD) + .append("\":{\"query_text\":\"") + .append(StringEscapeUtils.escapeJson(fact)) + .append("\",\"model_id\":\"") + .append(memoryConfig.getEmbeddingModelId()) + .append("\"}}}"); + boolQuery.must(QueryBuilders.wrapperQuery(neuralSearchQuery.toString())); + } else if (memoryConfig.getEmbeddingModelType() == FunctionName.SPARSE_ENCODING) { + StringBuilder neuralSparseQuery = new StringBuilder() + .append("{\"neural_sparse\":{\"") + .append(MEMORY_EMBEDDING_FIELD) + .append("\":{\"query_text\":\"") + .append(StringEscapeUtils.escapeJson(fact)) + .append("\",\"model_id\":\"") + .append(memoryConfig.getEmbeddingModelId()) + .append("\"}}}"); + boolQuery.must(QueryBuilders.wrapperQuery(neuralSparseQuery.toString())); + } else { + throw new IllegalStateException("Unsupported embedding model type: " + memoryConfig.getEmbeddingModelType()); + } + } else { + boolQuery.must(QueryBuilders.matchQuery(MEMORY_FIELD, fact)); + } + + return boolQuery; + } + + /** + * Creates an ingest pipeline in remote storage + * + * @param connectorId The connector ID to use for remote storage + * @param pipelineName The name of the pipeline to create + * @param pipelineBody The pipeline configuration as a JSON string + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void createRemotePipeline( + String connectorId, + String pipelineName, + String pipelineBody, + Client client, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put("pipeline_name", pipelineName); + parameters.put(INPUT_PARAM, pipelineBody); + parameters.put(CONNECTOR_ACTION_FIELD, "create_ingest_pipeline"); + + // Execute the connector action + executeConnectorAction(connectorId, CREATE_INGEST_PIPELINE_ACTION, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully created remote pipeline: {}", pipelineName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote pipeline: {}", pipelineName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote pipeline creation for: {}", pipelineName, e); + listener.onFailure(e); + } + } + + /** + 
* Creates long-term memory index in remote storage with a pipeline attached + * + * @param connectorId The connector ID to use for remote storage + * @param indexName The name of the index to create + * @param pipelineName The name of the pipeline to attach + * @param memoryConfig The memory configuration + * @param mlIndicesHandler The ML indices handler + * @param client The OpenSearch client + * @param listener The action listener + */ + public static void createRemoteLongTermMemoryIndexWithPipeline( + String connectorId, + String indexName, + String pipelineName, + MemoryConfiguration memoryConfig, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + try { + String indexMapping = buildLongTermMemoryMapping(memoryConfig, mlIndicesHandler); + Map indexSettings = buildLongTermMemorySettings(memoryConfig); + + // Parse the mapping string to a Map + Map mappingMap = parseMappingToMap(indexMapping); + + // Build the request body for creating the index with pipeline + Map requestBody = new HashMap<>(); + requestBody.put("mappings", mappingMap); + + // Add settings with default pipeline (settings already have "index." prefix) + Map settings = new HashMap<>(indexSettings); + settings.put("index.default_pipeline", pipelineName); + requestBody.put("settings", settings); + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJson(requestBody)); + parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); + + // Execute the connector action + executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully created remote long-term memory index with pipeline: {}", indexName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote long-term memory index with pipeline: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote long-term memory index creation with pipeline for: {}", indexName, e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MemorySearchQueryBuilder.java b/plugin/src/main/java/org/opensearch/ml/utils/MemorySearchQueryBuilder.java index aa5d49b34d..e92260718a 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MemorySearchQueryBuilder.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MemorySearchQueryBuilder.java @@ -25,6 +25,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import lombok.experimental.UtilityClass; import lombok.extern.log4j.Log4j2; @@ -181,4 +182,131 @@ public static QueryBuilder buildFactSearchQuery( return boolQuery; } + + /** + * Builds a fact search query for AOSS with neural search support + * Similar to buildFactSearchQuery but returns a JSON string for remote execution + * + * @param strategy The memory strategy containing namespace information + * @param fact The fact to search for + * @param namespace The namespace map for filtering + * @param ownerId The owner ID for filtering + * @param memoryConfig The memory storage configuration + * @param memoryContainerId The memory container ID to filter by + * @param maxInferSize Maximum number of results to return + * @return JSON string with the search query + */ + public static 
String buildFactSearchQueryForAoss( + MemoryStrategy strategy, + String fact, + Map namespace, + String ownerId, + MemoryConfiguration memoryConfig, + String memoryContainerId, + int maxInferSize + ) { + StringBuilder queryBuilder = new StringBuilder(); + queryBuilder.append("{\"size\":").append(maxInferSize).append(",\"query\":{\"bool\":{\"filter\":["); + + // Add filter conditions + boolean firstFilter = true; + for (String key : strategy.getNamespace()) { + if (!namespace.containsKey(key)) { + throw new IllegalArgumentException("Namespace does not contain key: " + key); + } + if (!firstFilter) { + queryBuilder.append(","); + } + queryBuilder + .append("{\"term\":{\"") + .append(NAMESPACE_FIELD) + .append(".") + .append(key) + .append("\":\"") + .append(StringEscapeUtils.escapeJson(namespace.get(key))) + .append("\"}}"); + firstFilter = false; + } + + if (ownerId != null) { + if (!firstFilter) { + queryBuilder.append(","); + } + queryBuilder + .append("{\"term\":{\"") + .append(OWNER_ID_FIELD) + .append("\":\"") + .append(StringEscapeUtils.escapeJson(ownerId)) + .append("\"}}"); + firstFilter = false; + } + + if (!firstFilter) { + queryBuilder.append(","); + } + queryBuilder.append("{\"term\":{\"").append(NAMESPACE_SIZE_FIELD).append("\":").append(strategy.getNamespace().size()).append("}}"); + + // Filter by strategy_id to prevent cross-strategy interference (sufficient for uniqueness) + queryBuilder + .append(",{\"term\":{\"") + .append(STRATEGY_ID_FIELD) + .append("\":\"") + .append(StringEscapeUtils.escapeJson(strategy.getId())) + .append("\"}}"); + + // Filter by memory_container_id to prevent cross-container access when containers share the same index prefix + if (memoryContainerId != null && !memoryContainerId.isBlank()) { + queryBuilder + .append(",{\"term\":{\"") + .append(MEMORY_CONTAINER_ID_FIELD) + .append("\":\"") + .append(StringEscapeUtils.escapeJson(memoryContainerId)) + .append("\"}}"); + } + + queryBuilder.append("],\"must\":["); + + RemoteStore remoteStore = memoryConfig.getRemoteStore(); + // Add the search query based on embedding type + if (remoteStore != null && remoteStore.getEmbeddingModelId() != null) { + // Determine which embedding model ID to use + String embeddingModelId = remoteStore.getEmbeddingModelId(); + + if (remoteStore.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING) { + // Neural search for dense embeddings + queryBuilder + .append("{\"neural\":{\"") + .append(MEMORY_EMBEDDING_FIELD) + .append("\":{\"query_text\":\"") + .append(StringEscapeUtils.escapeJson(fact)) + .append("\",\"model_id\":\"") + .append(StringEscapeUtils.escapeJson(embeddingModelId)) + .append("\"}}}"); + } else if (remoteStore.getEmbeddingModelType() == FunctionName.SPARSE_ENCODING) { + // Neural sparse search for sparse embeddings + queryBuilder + .append("{\"neural_sparse\":{\"") + .append(MEMORY_EMBEDDING_FIELD) + .append("\":{\"query_text\":\"") + .append(StringEscapeUtils.escapeJson(fact)) + .append("\",\"model_id\":\"") + .append(StringEscapeUtils.escapeJson(embeddingModelId)) + .append("\"}}}"); + } else { + throw new IllegalStateException("Unsupported embedding model type: " + memoryConfig.getEmbeddingModelType()); + } + } else { + // Fallback to match query if no embedding configured + queryBuilder + .append("{\"match\":{\"") + .append(MEMORY_FIELD) + .append("\":\"") + .append(StringEscapeUtils.escapeJson(fact)) + .append("\"}}"); + } + + queryBuilder.append("]}}}"); + + return queryBuilder.toString(); + } } diff --git 
a/plugin/src/test/java/org/opensearch/ml/helper/MemoryContainerHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/MemoryContainerHelperTests.java index e85c62c108..a206a0dcfd 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/MemoryContainerHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/MemoryContainerHelperTests.java @@ -723,4 +723,129 @@ private SearchResponse createSearchResponse(long totalHits) { SearchResponse.Clusters.EMPTY ); } + + public void testConvertBulkRequestToNDJSON_IndexRequest() throws Exception { + BulkRequest bulkRequest = new BulkRequest(); + + // Add index request with ID + IndexRequest indexRequest1 = new IndexRequest("test-index").id("doc1").source("{\"field\":\"value1\"}"); + bulkRequest.add(indexRequest1); + + // Add index request without ID + IndexRequest indexRequest2 = new IndexRequest("test-index").source("{\"field\":\"value2\"}"); + bulkRequest.add(indexRequest2); + + String ndjson = invokeConvertBulkRequestToNDJSON(bulkRequest); + + // Verify the NDJSON format + String[] lines = ndjson.split("\n"); + assertEquals(4, lines.length); // 2 requests * 2 lines each + + // First request with ID + assertTrue(lines[0].contains("\"index\"")); + assertTrue(lines[0].contains("\"_index\":\"test-index\"")); + assertTrue(lines[0].contains("\"_id\":\"doc1\"")); + assertEquals("{\"field\":\"value1\"}", lines[1]); + + // Second request without ID + assertTrue(lines[2].contains("\"index\"")); + assertTrue(lines[2].contains("\"_index\":\"test-index\"")); + assertFalse(lines[2].contains("\"_id\"")); + assertEquals("{\"field\":\"value2\"}", lines[3]); + } + + public void testConvertBulkRequestToNDJSON_UpdateRequest() throws Exception { + BulkRequest bulkRequest = new BulkRequest(); + + // Add update request with doc + UpdateRequest updateRequest = new UpdateRequest("test-index", "doc1").doc("{\"field\":\"updated_value\"}"); + bulkRequest.add(updateRequest); + + String ndjson = invokeConvertBulkRequestToNDJSON(bulkRequest); + + // Verify the NDJSON format + String[] lines = ndjson.split("\n"); + assertEquals(2, lines.length); + + assertTrue(lines[0].contains("\"update\"")); + assertTrue(lines[0].contains("\"_index\":\"test-index\"")); + assertTrue(lines[0].contains("\"_id\":\"doc1\"")); + assertTrue(lines[1].contains("\"doc\"")); + assertTrue(lines[1].contains("\"field\":\"updated_value\"")); + } + + public void testConvertBulkRequestToNDJSON_DeleteRequest() throws Exception { + BulkRequest bulkRequest = new BulkRequest(); + + // Add delete request + DeleteRequest deleteRequest = new DeleteRequest("test-index", "doc1"); + bulkRequest.add(deleteRequest); + + String ndjson = invokeConvertBulkRequestToNDJSON(bulkRequest); + + // Verify the NDJSON format + String[] lines = ndjson.split("\n"); + assertEquals(1, lines.length); // Delete only has action line, no document line + + assertTrue(lines[0].contains("\"delete\"")); + assertTrue(lines[0].contains("\"_index\":\"test-index\"")); + assertTrue(lines[0].contains("\"_id\":\"doc1\"")); + } + + public void testConvertBulkRequestToNDJSON_MixedRequests() throws Exception { + BulkRequest bulkRequest = new BulkRequest(); + + // Add different types of requests + bulkRequest.add(new IndexRequest("test-index").id("doc1").source("{\"field\":\"value1\"}")); + bulkRequest.add(new UpdateRequest("test-index", "doc2").doc("{\"field\":\"updated\"}")); + bulkRequest.add(new DeleteRequest("test-index", "doc3")); + + String ndjson = invokeConvertBulkRequestToNDJSON(bulkRequest); + + // Verify the NDJSON format + 
String[] lines = ndjson.split("\n"); + assertEquals(5, lines.length); // index(2) + update(2) + delete(1) + + // Index request + assertTrue(lines[0].contains("\"index\"")); + assertEquals("{\"field\":\"value1\"}", lines[1]); + + // Update request + assertTrue(lines[2].contains("\"update\"")); + assertTrue(lines[3].contains("\"doc\"")); + + // Delete request + assertTrue(lines[4].contains("\"delete\"")); + } + + public void testConvertBulkRequestToNDJSON_RealWorldExample() throws Exception { + BulkRequest bulkRequest = new BulkRequest(); + + // Real-world example similar to memory container bulk write + String docJson = + "{\"created_time\":1760760969707,\"memory\":\"Bob likes swimming.\",\"last_updated_time\":1760760969707,\"namespace_size\":1,\"owner_id\":\"admin\",\"namespace\":{\"user_id\":\"bob\"},\"strategy_id\":\"semantic_9766b0fe\",\"strategy_type\":\"SEMANTIC\",\"memory_container_id\":\"B4DU85kBZsSZwpNve_T0\",\"tags\":{\"topic\":\"personal info\"}}"; + + IndexRequest indexRequest = new IndexRequest("demo2-memory-long-term").source(docJson); + bulkRequest.add(indexRequest); + + String ndjson = invokeConvertBulkRequestToNDJSON(bulkRequest); + + // Verify it's valid NDJSON + String[] lines = ndjson.split("\n"); + assertEquals(2, lines.length); + + assertTrue(lines[0].contains("\"index\"")); + assertTrue(lines[0].contains("\"_index\":\"demo2-memory-long-term\"")); + assertEquals(docJson, lines[1]); + + // Verify the entire NDJSON is valid using StringUtils + assertTrue(org.opensearch.ml.common.utils.StringUtils.isJsonOrNdjson(ndjson)); + } + + // Helper method to invoke private convertBulkRequestToNDJSON method + private String invokeConvertBulkRequestToNDJSON(BulkRequest bulkRequest) throws Exception { + java.lang.reflect.Method method = MemoryContainerHelper.class.getDeclaredMethod("convertBulkRequestToNDJSON", BulkRequest.class); + method.setAccessible(true); + return (String) method.invoke(helper, bulkRequest); + } } From c13a7af70841a313b8d7116fcd8609678d6f77b6 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 22 Oct 2025 12:35:09 -0700 Subject: [PATCH 03/58] auto create connector Signed-off-by: Yaliang Wu --- .../common/memorycontainer/RemoteStore.java | 63 +++++- .../TransportCreateMemoryContainerAction.java | 209 ++++++++++++++++++ 2 files changed, 271 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java index 9f140b5505..15bc4c2df5 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_TYPE_FIELD; import java.io.IOException; +import java.util.Map; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -33,12 +34,20 @@ public class RemoteStore implements ToXContentObject, Writeable { public static final String TYPE_FIELD = "type"; public static final String CONNECTOR_ID_FIELD = "connector_id"; + public static final String ENDPOINT_FIELD = "endpoint"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String CREDENTIAL_FIELD = "credential"; private RemoteStoreType type; private String connectorId; private FunctionName embeddingModelType; private String embeddingModelId; private Integer 
embeddingDimension; + + // Auto-connector creation fields + private String endpoint; + private Map parameters; + private Map credential; @Builder public RemoteStore( @@ -46,7 +55,10 @@ public RemoteStore( String connectorId, FunctionName embeddingModelType, String embeddingModelId, - Integer embeddingDimension + Integer embeddingDimension, + String endpoint, + Map parameters, + Map credential ) { if (type == null) { throw new IllegalArgumentException("Invalid remote store type"); @@ -56,6 +68,9 @@ public RemoteStore( this.embeddingModelType = embeddingModelType; this.embeddingModelId = embeddingModelId; this.embeddingDimension = embeddingDimension; + this.endpoint = endpoint; + this.parameters = parameters != null ? new java.util.HashMap<>(parameters) : new java.util.HashMap<>(); + this.credential = credential != null ? new java.util.HashMap<>(credential) : new java.util.HashMap<>(); } public RemoteStore(StreamInput input) throws IOException { @@ -66,6 +81,17 @@ public RemoteStore(StreamInput input) throws IOException { } this.embeddingModelId = input.readOptionalString(); this.embeddingDimension = input.readOptionalInt(); + this.endpoint = input.readOptionalString(); + if (input.readBoolean()) { + this.parameters = input.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.parameters = new java.util.HashMap<>(); + } + if (input.readBoolean()) { + this.credential = input.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.credential = new java.util.HashMap<>(); + } } @Override @@ -80,6 +106,19 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalString(embeddingModelId); out.writeOptionalInt(embeddingDimension); + out.writeOptionalString(endpoint); + if (parameters != null && !parameters.isEmpty()) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + if (credential != null && !credential.isEmpty()) { + out.writeBoolean(true); + out.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } } @Override @@ -100,6 +139,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (embeddingDimension != null) { builder.field(DIMENSION_FIELD, embeddingDimension); } + if (endpoint != null) { + builder.field(ENDPOINT_FIELD, endpoint); + } + if (parameters != null && !parameters.isEmpty()) { + builder.field(PARAMETERS_FIELD, parameters); + } + // Don't serialize credentials for security - they are stored in the connector builder.endObject(); return builder; } @@ -110,6 +156,9 @@ public static RemoteStore parse(XContentParser parser) throws IOException { FunctionName embeddingModelType = null; String embeddingModelId = null; Integer embeddingDimension = null; + String endpoint = null; + Map parameters = new java.util.HashMap<>(); + Map credential = new java.util.HashMap<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -132,6 +181,15 @@ public static RemoteStore parse(XContentParser parser) throws IOException { case DIMENSION_FIELD: embeddingDimension = parser.intValue(); break; + case ENDPOINT_FIELD: + endpoint = parser.text(); + break; + case PARAMETERS_FIELD: + parameters = parser.mapStrings(); + break; + case CREDENTIAL_FIELD: + credential = parser.mapStrings(); + break; default: parser.skipChildren(); break; @@ -145,6 +203,9 
@@ public static RemoteStore parse(XContentParser parser) throws IOException { .embeddingModelType(embeddingModelType) .embeddingModelId(embeddingModelId) .embeddingDimension(embeddingDimension) + .endpoint(endpoint) + .parameters(parameters) + .credential(credential) .build(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index a06bbdf010..8b76beaa81 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerAction; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerInput; @@ -329,6 +330,24 @@ private void indexMemoryContainer(MLMemoryContainer container, ActionListener listener) { + // Check if we need to auto-create a connector + if (config.getRemoteStore() != null && config.getRemoteStore().getConnectorId() == null + && config.getRemoteStore().getEndpoint() != null) { + // Auto-create connector first + createConnectorForRemoteStore(config.getRemoteStore(), ActionListener.wrap(connectorId -> { + // Set the connector ID in the remote store config + config.getRemoteStore().setConnectorId(connectorId); + log.info("Auto-created connector with ID: {} for remote store", connectorId); + // Continue with normal validation + validateConfigurationInternal(config, listener); + }, listener::onFailure)); + } else { + // Normal validation flow + validateConfigurationInternal(config, listener); + } + } + + private void validateConfigurationInternal(MemoryConfiguration config, ActionListener listener) { // Validate that strategies have required AI models try { MemoryConfiguration.validateStrategiesRequireModels(config); @@ -431,4 +450,194 @@ private void createRemoteLongTermMemoryIngestPipeline( .createRemoteLongTermMemoryIngestPipeline(connectorId, indexName, configuration, mlIndicesHandler, client, listener); } + /** + * Auto-creates a connector for remote store based on provided endpoint and credentials + */ + private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListener listener) { + try { + String connectorName = "auto_" + remoteStore.getType().name().toLowerCase() + "_connector_" + + java.util.UUID.randomUUID().toString().substring(0, 8); + + // Build connector actions based on remote store type + java.util.List actions = buildConnectorActions(remoteStore); + + // Get credential and parameters from remote store + java.util.Map credential = remoteStore.getCredential(); + java.util.Map parameters = remoteStore.getParameters(); + + // Determine protocol based on parameters or credential + String protocol = determineProtocol(parameters, credential); + + // Create connector input + org.opensearch.ml.common.transport.connector.MLCreateConnectorInput connectorInput = + org.opensearch.ml.common.transport.connector.MLCreateConnectorInput.builder() + .name(connectorName) + 
.description("Auto-generated connector for " + remoteStore.getType() + " remote store") + .version("1") + .protocol(protocol) + .parameters(parameters) + .credential(credential) + .actions(actions) + .build(); + + // Create connector request + org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest request = + new org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest(connectorInput); + + // Execute connector creation + client.execute( + org.opensearch.ml.common.transport.connector.MLCreateConnectorAction.INSTANCE, + request, + ActionListener.wrap( + response -> { + log.info("Successfully created connector: {}", response.getConnectorId()); + listener.onResponse(response.getConnectorId()); + }, + e -> { + log.error("Failed to create connector for remote store", e); + listener.onFailure(e); + } + ) + ); + } catch (Exception e) { + log.error("Error building connector for remote store", e); + listener.onFailure(e); + } + } + + /** + * Determines the protocol based on parameters and credentials + */ + private String determineProtocol(java.util.Map parameters, java.util.Map credential) { + // Check if service_name is in parameters (indicates AWS SigV4) + if (parameters != null && parameters.containsKey("service_name")) { + return "aws_sigv4"; + } + // Check if roleArn is in credential (indicates AWS SigV4) + if (credential != null && credential.containsKey("roleArn")) { + return "aws_sigv4"; + } + // Check if access_key and secret_key are in credential (indicates AWS SigV4) + if (credential != null && credential.containsKey("access_key") && credential.containsKey("secret_key")) { + return "aws_sigv4"; + } + // Default to http (for basic auth or other) + return "http"; + } + + /** + * Builds connector actions based on remote store type + */ + private java.util.List buildConnectorActions(RemoteStore remoteStore) { + java.util.List actions = new java.util.ArrayList<>(); + String endpoint = remoteStore.getEndpoint(); + java.util.Map parameters = remoteStore.getParameters(); + java.util.Map credential = remoteStore.getCredential(); + + // Determine if AWS SigV4 or basic auth + boolean isAwsSigV4 = (parameters != null && parameters.containsKey("service_name")) + || (credential != null && (credential.containsKey("roleArn") || credential.containsKey("access_key"))); + boolean isBasicAuth = credential != null && credential.containsKey("basic_auth_key"); + + // Common headers for JSON + java.util.Map jsonHeaders = new java.util.HashMap<>(); + jsonHeaders.put("content-type", "application/json"); + if (isAwsSigV4) { + jsonHeaders.put("x-amz-content-sha256", "required"); + } + if (isBasicAuth) { + jsonHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); + } + + // Create ingest pipeline action + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("create_ingest_pipeline") + .method("PUT") + .url(endpoint + "/_ingest/pipeline/${parameters.pipeline_name}") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build()); + + // Create index action + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("create_index") + .method("PUT") + .url(endpoint + "/${parameters.index_name}") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build()); + + // Write doc action + 
actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("write_doc") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_doc") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build()); + + // Bulk load action + java.util.Map bulkHeaders = new java.util.HashMap<>(); + bulkHeaders.put("content-type", "application/x-ndjson"); + if (isAwsSigV4) { + bulkHeaders.put("x-amz-content-sha256", "required"); + } + if (isBasicAuth) { + bulkHeaders.put("Authorization", "Basic ${credential.auth_key}"); + } + + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("bulk_load") + .method("POST") + .url(endpoint + "/_bulk") + .headers(bulkHeaders) + .requestBody("${parameters.input}") + .build()); + + // Search index action + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("search_index") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_search") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build()); + + // Get doc action + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("get_doc") + .method("GET") + .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") + .headers(jsonHeaders) + .build()); + + // Delete doc action + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("delete_doc") + .method("DELETE") + .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") + .headers(jsonHeaders) + .build()); + + // Update doc action - POST /_update/ works on both OpenSearch and AOSS + // Uses partial update with "doc" wrapper for flexibility + actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("update_doc") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_update/${parameters.doc_id}") + .headers(jsonHeaders) + .requestBody("{ \"doc\": ${parameters.input:-} }") + .build()); + + return actions; + } + } From 39bd765b8dd2c653f5d3c5dda1ca32440a1396b0 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 22 Oct 2025 18:13:35 -0700 Subject: [PATCH 04/58] auto create embedding model in remote store Signed-off-by: Yaliang Wu --- .../memorycontainer/RemoteEmbeddingModel.java | 181 ++++++++++++ .../common/memorycontainer/RemoteStore.java | 27 +- .../main/resources/model-connectors/README.md | 142 +++++++++ .../amazon.titan-embed-text-v2-0.json | 20 ++ .../algorithms/remote/ConnectorUtils.java | 38 +-- .../TransportCreateMemoryContainerAction.java | 276 +++++++++++------- .../ml/helper/RemoteStorageHelper.java | 171 ++++++++++- 7 files changed, 729 insertions(+), 126 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteEmbeddingModel.java create mode 100644 common/src/main/resources/model-connectors/README.md create mode 100644 common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json diff --git 
a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteEmbeddingModel.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteEmbeddingModel.java new file mode 100644 index 0000000000..9cdd7d069f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteEmbeddingModel.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memorycontainer; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; + +/** + * Configuration for embedding model in remote store + * Allows automatic creation of embedding model in remote AOSS collection + */ +@Data +@EqualsAndHashCode +public class RemoteEmbeddingModel implements ToXContentObject, Writeable { + + public static final String PROVIDER_FIELD = "model_provider"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String MODEL_TYPE_FIELD = "model_type"; + public static final String DIMENSION_FIELD = "dimension"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String CREDENTIAL_FIELD = "credential"; + + private String provider; // e.g., "bedrock", "openai", "cohere" + private String modelId; // e.g., "amazon.titan-embed-text-v2:0" + private FunctionName modelType; // TEXT_EMBEDDING or SPARSE_ENCODING + private Integer dimension; + private Map parameters; + private Map credential; + + @Builder + public RemoteEmbeddingModel( + String provider, + String modelId, + FunctionName modelType, + Integer dimension, + Map parameters, + Map credential + ) { + this.provider = provider; + this.modelId = modelId; + this.modelType = modelType; + this.dimension = dimension; + this.parameters = parameters != null ? new HashMap<>(parameters) : new HashMap<>(); + this.credential = credential != null ? 
new HashMap<>(credential) : new HashMap<>(); + } + + public RemoteEmbeddingModel(StreamInput input) throws IOException { + this.provider = input.readOptionalString(); + this.modelId = input.readOptionalString(); + if (input.readOptionalBoolean()) { + this.modelType = input.readEnum(FunctionName.class); + } + this.dimension = input.readOptionalInt(); + if (input.readBoolean()) { + this.parameters = input.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.parameters = new HashMap<>(); + } + if (input.readBoolean()) { + this.credential = input.readMap(StreamInput::readString, StreamInput::readString); + } else { + this.credential = new HashMap<>(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(provider); + out.writeOptionalString(modelId); + if (modelType != null) { + out.writeBoolean(true); + out.writeEnum(modelType); + } else { + out.writeBoolean(false); + } + out.writeOptionalInt(dimension); + if (parameters != null && !parameters.isEmpty()) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + if (credential != null && !credential.isEmpty()) { + out.writeBoolean(true); + out.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (provider != null) { + builder.field(PROVIDER_FIELD, provider); + } + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (modelType != null) { + builder.field(MODEL_TYPE_FIELD, modelType); + } + if (dimension != null) { + builder.field(DIMENSION_FIELD, dimension); + } + if (parameters != null && !parameters.isEmpty()) { + builder.field(PARAMETERS_FIELD, parameters); + } + // Don't serialize credentials for security + builder.endObject(); + return builder; + } + + public static RemoteEmbeddingModel parse(XContentParser parser) throws IOException { + String provider = null; + String modelId = null; + FunctionName modelType = null; + Integer dimension = null; + Map parameters = new HashMap<>(); + Map credential = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case PROVIDER_FIELD: + provider = parser.text(); + break; + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case MODEL_TYPE_FIELD: + modelType = FunctionName.from(parser.text()); + break; + case DIMENSION_FIELD: + dimension = parser.intValue(); + break; + case PARAMETERS_FIELD: + parameters = StringUtils.getParameterMap(parser.map()); + break; + case CREDENTIAL_FIELD: + credential = StringUtils.getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + + return RemoteEmbeddingModel + .builder() + .provider(provider) + .modelId(modelId) + .modelType(modelType) + .dimension(dimension) + .parameters(parameters) + .credential(credential) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java index 15bc4c2df5..264af6c3b9 100644 --- 
a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java @@ -37,18 +37,22 @@ public class RemoteStore implements ToXContentObject, Writeable { public static final String ENDPOINT_FIELD = "endpoint"; public static final String PARAMETERS_FIELD = "parameters"; public static final String CREDENTIAL_FIELD = "credential"; + public static final String EMBEDDING_MODEL_FIELD = "embedding_model"; private RemoteStoreType type; private String connectorId; private FunctionName embeddingModelType; private String embeddingModelId; private Integer embeddingDimension; - + // Auto-connector creation fields private String endpoint; private Map parameters; private Map credential; + // Auto embedding model creation + private RemoteEmbeddingModel embeddingModel; + @Builder public RemoteStore( RemoteStoreType type, @@ -58,7 +62,8 @@ public RemoteStore( Integer embeddingDimension, String endpoint, Map parameters, - Map credential + Map credential, + RemoteEmbeddingModel embeddingModel ) { if (type == null) { throw new IllegalArgumentException("Invalid remote store type"); @@ -71,6 +76,7 @@ public RemoteStore( this.endpoint = endpoint; this.parameters = parameters != null ? new java.util.HashMap<>(parameters) : new java.util.HashMap<>(); this.credential = credential != null ? new java.util.HashMap<>(credential) : new java.util.HashMap<>(); + this.embeddingModel = embeddingModel; } public RemoteStore(StreamInput input) throws IOException { @@ -92,6 +98,9 @@ public RemoteStore(StreamInput input) throws IOException { } else { this.credential = new java.util.HashMap<>(); } + if (input.readBoolean()) { + this.embeddingModel = new RemoteEmbeddingModel(input); + } } @Override @@ -119,6 +128,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (embeddingModel != null) { + out.writeBoolean(true); + embeddingModel.writeTo(out); + } else { + out.writeBoolean(false); + } } @Override @@ -145,6 +160,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (parameters != null && !parameters.isEmpty()) { builder.field(PARAMETERS_FIELD, parameters); } + if (embeddingModel != null) { + builder.field(EMBEDDING_MODEL_FIELD, embeddingModel); + } // Don't serialize credentials for security - they are stored in the connector builder.endObject(); return builder; @@ -159,6 +177,7 @@ public static RemoteStore parse(XContentParser parser) throws IOException { String endpoint = null; Map parameters = new java.util.HashMap<>(); Map credential = new java.util.HashMap<>(); + RemoteEmbeddingModel embeddingModel = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -190,6 +209,9 @@ public static RemoteStore parse(XContentParser parser) throws IOException { case CREDENTIAL_FIELD: credential = parser.mapStrings(); break; + case EMBEDDING_MODEL_FIELD: + embeddingModel = RemoteEmbeddingModel.parse(parser); + break; default: parser.skipChildren(); break; @@ -206,6 +228,7 @@ public static RemoteStore parse(XContentParser parser) throws IOException { .endpoint(endpoint) .parameters(parameters) .credential(credential) + .embeddingModel(embeddingModel) .build(); } } diff --git a/common/src/main/resources/model-connectors/README.md b/common/src/main/resources/model-connectors/README.md new file mode 100644 index 0000000000..20b59c1074 --- /dev/null +++ 
b/common/src/main/resources/model-connectors/README.md
@@ -0,0 +1,142 @@
+# Model Connector Templates
+
+This directory contains connector templates for various AI model providers and functions.
+
+## Directory Structure
+
+```
+model-connectors/
+├── <provider>/
+│   ├── <model_id>.json
+│   └── ...
+└── ...
+```
+
+### Components
+
+- **provider**: The model provider (e.g., `bedrock/text_embedding`, `openai`, `cohere`, `huggingface`)
+- **model_id**: The specific model identifier with `:` replaced by `-` (e.g., `amazon.titan-embed-text-v2-0` for `amazon.titan-embed-text-v2:0`)
+
+## Current Templates
+
+### Bedrock
+
+#### Text Embedding
+- `bedrock/text_embedding/amazon.titan-embed-text-v2-0.json` - Amazon Titan Text Embeddings v2
+
+## Template Format
+
+Each template is a JSON file containing the base connector configuration:
+
+```json
+{
+  "name": "Provider Model Name",
+  "description": "Description of the connector",
+  "version": 1,
+  "protocol": "aws_sigv4",
+  "actions": [
+    {
+      "action_type": "predict",
+      "method": "POST",
+      "url": "https://...",
+      "headers": {...},
+      "request_body": "...",
+      "pre_process_function": "...",
+      "post_process_function": "..."
+    }
+  ]
+}
+```
+
+**Note**: The `parameters` and `credential` blocks are NOT included in the template. They are provided by the user and injected at runtime.
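+
+As a rough, illustrative sketch only (not part of any template), a user-supplied block for the Bedrock Titan template shown under Examples below might look like this. The keys simply mirror the `${parameters.*}` and `${credential.*}` placeholders in that template; the region, dimensions, normalize, and embeddingTypes values and the `access_key`/`secret_key` credential names are assumptions, and `inputText` is filled per prediction request rather than supplied here:
+
+```json
+{
+  "parameters": {
+    "region": "us-east-1",
+    "model": "amazon.titan-embed-text-v2:0",
+    "dimensions": 1024,
+    "normalize": true,
+    "embeddingTypes": ["float"]
+  },
+  "credential": {
+    "access_key": "<your-access-key>",
+    "secret_key": "<your-secret-key>"
+  }
+}
+```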
+
+## Adding New Templates
+
+### 1. Create Directory Structure
+
+```bash
+mkdir -p common/src/main/resources/model-connectors/<provider>/
+```
+
+### 2. Create Template File
+
+Create a JSON file named `<model_id>.json` (with `:` replaced by `-`):
+
+```bash
+# Example for OpenAI text-embedding-3-small
+common/src/main/resources/model-connectors/openai/text_embedding/text-embedding-3-small.json
+```
+
+### 3. Template Content
+
+Include only the base structure:
+- `name`, `description`, `version`
+- `protocol`
+- `actions` array with action definitions
+
+Do NOT include:
+- `parameters` block (user-provided)
+- `credential` block (user-provided)
+
+### 4. Placeholders in Actions
+
+Use placeholders in action URLs and request bodies:
+- `${parameters.param_name}` - Will be filled from user-provided parameters
+- `${credential.cred_name}` - Will be filled from user-provided credentials
+
+## Examples
+
+### Bedrock Titan Embedding
+
+**File**: `bedrock/text_embedding/amazon.titan-embed-text-v2-0.json`
+
+```json
+{
+  "name": "Amazon Bedrock Connector: embedding",
+  "description": "Connector to bedrock embedding model",
+  "version": 1,
+  "protocol": "aws_sigv4",
+  "actions": [
+    {
+      "action_type": "predict",
+      "method": "POST",
+      "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke",
+      "headers": {
+        "content-type": "application/json",
+        "x-amz-content-sha256": "required"
+      },
+      "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}, \"normalize\": ${parameters.normalize}, \"embeddingTypes\": ${parameters.embeddingTypes} }",
+      "pre_process_function": "connector.pre_process.bedrock.embedding",
+      "post_process_function": "connector.post_process.bedrock.embedding"
+    }
+  ]
+}
+```
+
+### Future: OpenAI Embedding (Example)
+
+**File**: `openai/text_embedding/text-embedding-3-small.json`
+
+```json
+{
+  "name": "OpenAI Embedding Connector",
+  "description": "Connector to OpenAI embedding model",
+  "version": 1,
+  "protocol": "http",
+  "actions": [
+    {
+      "action_type": "predict",
+      "method": "POST",
+      "url": "https://api.openai.com/v1/embeddings",
+      "headers": {
+        "Authorization": "Bearer ${credential.api_key}",
+        "content-type": "application/json"
+      },
+      "request_body": "{ \"input\": \"${parameters.input}\", \"model\": \"${parameters.model}\" }",
+      "pre_process_function": "connector.pre_process.openai.embedding",
+      "post_process_function": "connector.post_process.openai.embedding"
+    }
+  ]
+}
+```
+
diff --git a/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json b/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json
new file mode 100644
index 0000000000..09d076a79b
--- /dev/null
+++ b/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json
@@ -0,0 +1,20 @@
+{
+  "name": "Amazon Bedrock Connector: embedding",
+  "description": "Connector to bedrock embedding model",
+  "version": 1,
+  "protocol": "aws_sigv4",
+  "actions": [
+    {
+      "action_type": "predict",
+      "method": "POST",
+      "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke",
+      "headers": {
+        "content-type": "application/json",
+        "x-amz-content-sha256": "required"
+      },
+      "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}, \"normalize\": ${parameters.normalize}, \"embeddingTypes\": ${parameters.embeddingTypes} }",
+      "pre_process_function": "connector.pre_process.bedrock.embedding",
+      "post_process_function": "connector.post_process.bedrock.embedding"
+    }
+  ]
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
index cc0736c9e2..88c0ce7587 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java
@@ -169,11 +169,15 @@ private static MLInput escapeMLInput(MLInput mlInput) {
     }
 
     public static
void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) { - if (inputData.getParameters() == null) { - return; + inputData.setParameters(escapeRemoteInferenceInputData(inputData.getParameters())); + } + + public static Map escapeRemoteInferenceInputData(Map parameters) { + if (parameters == null) { + return parameters; } Map newParameters = new HashMap<>(); - String noEscapeParams = inputData.getParameters().get(NO_ESCAPE_PARAMS); + String noEscapeParams = parameters.get(NO_ESCAPE_PARAMS); Set noEscapParamSet = new HashSet<>(); if (noEscapeParams != null && !noEscapeParams.isEmpty()) { String[] keys = noEscapeParams.split(","); @@ -181,21 +185,19 @@ public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet in noEscapParamSet.add(key.trim()); } } - if (inputData.getParameters() != null) { - inputData.getParameters().forEach((key, value) -> { - if (value == null) { - newParameters.put(key, null); - } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { - // no need to escape if it's already valid json - newParameters.put(key, value); - } else if (!noEscapParamSet.contains(key)) { - newParameters.put(key, escapeJson(value)); - } else { - newParameters.put(key, value); - } - }); - inputData.setParameters(newParameters); - } + parameters.forEach((key, value) -> { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { + // no need to escape if it's already valid json + newParameters.put(key, value); + } else if (!noEscapParamSet.contains(key)) { + newParameters.put(key, escapeJson(value)); + } else { + newParameters.put(key, value); + } + }); + return newParameters; } private static String getPreprocessFunction(String action, MLInput mlInput, Connector connector) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 8b76beaa81..8480a27c22 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -24,6 +24,7 @@ import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; +import org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel; import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerAction; @@ -331,15 +332,39 @@ private void indexMemoryContainer(MLMemoryContainer container, ActionListener listener) { // Check if we need to auto-create a connector - if (config.getRemoteStore() != null && config.getRemoteStore().getConnectorId() == null + if (config.getRemoteStore() != null + && config.getRemoteStore().getConnectorId() == null && config.getRemoteStore().getEndpoint() != null) { // Auto-create connector first createConnectorForRemoteStore(config.getRemoteStore(), ActionListener.wrap(connectorId -> { // Set the connector ID in the remote store config config.getRemoteStore().setConnectorId(connectorId); log.info("Auto-created connector with ID: {} for remote store", connectorId); - // Continue with normal validation - 
validateConfigurationInternal(config, listener); + + // Check if we need to auto-create embedding model + if (config.getRemoteStore().getEmbeddingModel() != null) { + RemoteStorageHelper + .createRemoteEmbeddingModel( + connectorId, + config.getRemoteStore().getEmbeddingModel(), + config.getRemoteStore().getCredential(), + client, + ActionListener.wrap(modelId -> { + // Set the embedding model ID in the remote store config + config.getRemoteStore().setEmbeddingModelId(modelId); + // Also set type and dimension from embedding model config + RemoteEmbeddingModel embModel = config.getRemoteStore().getEmbeddingModel(); + config.getRemoteStore().setEmbeddingModelType(embModel.getModelType()); + config.getRemoteStore().setEmbeddingDimension(embModel.getDimension()); + log.info("Auto-created embedding model with ID: {} in remote store", modelId); + // Continue with normal validation + validateConfigurationInternal(config, listener); + }, listener::onFailure) + ); + } else { + // Continue with normal validation + validateConfigurationInternal(config, listener); + } }, listener::onFailure)); } else { // Normal validation flow @@ -396,8 +421,6 @@ private void validateConfigurationInternal(MemoryConfiguration config, ActionLis }, listener::onFailure)); } - // ==================== Remote Storage Helper Methods ==================== - private void createRemoteSessionMemoryIndex(MemoryConfiguration configuration, String indexName, ActionListener listener) { String connectorId = configuration.getRemoteStore().getConnectorId(); RemoteStorageHelper.createRemoteSessionMemoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); @@ -455,56 +478,57 @@ private void createRemoteLongTermMemoryIngestPipeline( */ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListener listener) { try { - String connectorName = "auto_" + remoteStore.getType().name().toLowerCase() + "_connector_" + String connectorName = "auto_" + + remoteStore.getType().name().toLowerCase() + + "_connector_" + java.util.UUID.randomUUID().toString().substring(0, 8); - + // Build connector actions based on remote store type java.util.List actions = buildConnectorActions(remoteStore); - + // Get credential and parameters from remote store java.util.Map credential = remoteStore.getCredential(); java.util.Map parameters = remoteStore.getParameters(); - + // Determine protocol based on parameters or credential String protocol = determineProtocol(parameters, credential); - + // Create connector input - org.opensearch.ml.common.transport.connector.MLCreateConnectorInput connectorInput = - org.opensearch.ml.common.transport.connector.MLCreateConnectorInput.builder() + org.opensearch.ml.common.transport.connector.MLCreateConnectorInput connectorInput = + org.opensearch.ml.common.transport.connector.MLCreateConnectorInput + .builder() .name(connectorName) - .description("Auto-generated connector for " + remoteStore.getType() + " remote store") + .description("Auto-generated connector for " + remoteStore.getType() + " remote memory store") .version("1") .protocol(protocol) .parameters(parameters) .credential(credential) .actions(actions) .build(); - + // Create connector request - org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest request = + org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest request = new org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest(connectorInput); - + // Execute connector creation - client.execute( - 
org.opensearch.ml.common.transport.connector.MLCreateConnectorAction.INSTANCE, - request, - ActionListener.wrap( - response -> { + client + .execute( + org.opensearch.ml.common.transport.connector.MLCreateConnectorAction.INSTANCE, + request, + ActionListener.wrap(response -> { log.info("Successfully created connector: {}", response.getConnectorId()); listener.onResponse(response.getConnectorId()); - }, - e -> { + }, e -> { log.error("Failed to create connector for remote store", e); listener.onFailure(e); - } - ) - ); + }) + ); } catch (Exception e) { log.error("Error building connector for remote store", e); listener.onFailure(e); } } - + /** * Determines the protocol based on parameters and credentials */ @@ -524,7 +548,7 @@ private String determineProtocol(java.util.Map parameters, java. // Default to http (for basic auth or other) return "http"; } - + /** * Builds connector actions based on remote store type */ @@ -533,12 +557,12 @@ private java.util.List build String endpoint = remoteStore.getEndpoint(); java.util.Map parameters = remoteStore.getParameters(); java.util.Map credential = remoteStore.getCredential(); - + // Determine if AWS SigV4 or basic auth - boolean isAwsSigV4 = (parameters != null && parameters.containsKey("service_name")) + boolean isAwsSigV4 = (parameters != null && parameters.containsKey("service_name")) || (credential != null && (credential.containsKey("roleArn") || credential.containsKey("access_key"))); boolean isBasicAuth = credential != null && credential.containsKey("basic_auth_key"); - + // Common headers for JSON java.util.Map jsonHeaders = new java.util.HashMap<>(); jsonHeaders.put("content-type", "application/json"); @@ -548,37 +572,49 @@ private java.util.List build if (isBasicAuth) { jsonHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); } - + // Create ingest pipeline action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("create_ingest_pipeline") - .method("PUT") - .url(endpoint + "/_ingest/pipeline/${parameters.pipeline_name}") - .headers(jsonHeaders) - .requestBody("${parameters.input}") - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("create_ingest_pipeline") + .method("PUT") + .url(endpoint + "/_ingest/pipeline/${parameters.pipeline_name}") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Create index action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("create_index") - .method("PUT") - .url(endpoint + "/${parameters.index_name}") - .headers(jsonHeaders) - .requestBody("${parameters.input}") - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("create_index") + .method("PUT") + .url(endpoint + "/${parameters.index_name}") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Write doc action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("write_doc") - .method("POST") - .url(endpoint + "/${parameters.index_name}/_doc") - 
.headers(jsonHeaders) - .requestBody("${parameters.input}") - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("write_doc") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_doc") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Bulk load action java.util.Map bulkHeaders = new java.util.HashMap<>(); bulkHeaders.put("content-type", "application/x-ndjson"); @@ -588,55 +624,89 @@ private java.util.List build if (isBasicAuth) { bulkHeaders.put("Authorization", "Basic ${credential.auth_key}"); } - - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("bulk_load") - .method("POST") - .url(endpoint + "/_bulk") - .headers(bulkHeaders) - .requestBody("${parameters.input}") - .build()); - + + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("bulk_load") + .method("POST") + .url(endpoint + "/_bulk") + .headers(bulkHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Search index action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("search_index") - .method("POST") - .url(endpoint + "/${parameters.index_name}/_search") - .headers(jsonHeaders) - .requestBody("${parameters.input}") - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("search_index") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_search") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Get doc action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("get_doc") - .method("GET") - .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") - .headers(jsonHeaders) - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("get_doc") + .method("GET") + .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") + .headers(jsonHeaders) + .build() + ); + // Delete doc action - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("delete_doc") - .method("DELETE") - .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") - .headers(jsonHeaders) - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("delete_doc") + .method("DELETE") + .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") + .headers(jsonHeaders) + .build() + ); + // Update doc action - POST /_update/ works on both OpenSearch and AOSS // Uses partial update with "doc" wrapper for flexibility - actions.add(org.opensearch.ml.common.connector.ConnectorAction.builder() - 
.actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("update_doc") - .method("POST") - .url(endpoint + "/${parameters.index_name}/_update/${parameters.doc_id}") - .headers(jsonHeaders) - .requestBody("{ \"doc\": ${parameters.input:-} }") - .build()); - + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("update_doc") + .method("POST") + .url(endpoint + "/${parameters.index_name}/_update/${parameters.doc_id}") + .headers(jsonHeaders) + .requestBody("{ \"doc\": ${parameters.input:-} }") + .build() + ); + + // Register model action - for creating embedding models in remote AOSS + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name("register_model") + .method("POST") + .url(endpoint + "/_plugins/_ml/models/_register") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + return actions; } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java index b5ff675097..b3e7a8ca00 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java @@ -5,6 +5,7 @@ package org.opensearch.ml.helper; +import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.ml.common.CommonValue.CONNECTOR_ACTION_FIELD; import static org.opensearch.ml.common.CommonValue.ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH; @@ -23,12 +24,12 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STRATEGY_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.WORKING_MEMORY_INDEX; import static org.opensearch.ml.common.utils.ToolUtils.NO_ESCAPE_PARAMS; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import org.apache.commons.text.StringEscapeUtils; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; @@ -67,6 +68,7 @@ @Log4j2 public class RemoteStorageHelper { + private static final String REGISTER_MODEL_ACTION = "register_model"; private static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; private static final String CREATE_INDEX_ACTION = "create_index"; private static final String WRITE_DOC_ACTION = "write_doc"; @@ -583,7 +585,7 @@ public static QueryBuilder buildFactSearchQuery( .append("{\"neural\":{\"") .append(MEMORY_EMBEDDING_FIELD) .append("\":{\"query_text\":\"") - .append(StringEscapeUtils.escapeJson(fact)) + .append(escapeJson(fact)) .append("\",\"model_id\":\"") .append(memoryConfig.getEmbeddingModelId()) .append("\"}}}"); @@ -593,7 +595,7 @@ public static QueryBuilder buildFactSearchQuery( .append("{\"neural_sparse\":{\"") .append(MEMORY_EMBEDDING_FIELD) .append("\":{\"query_text\":\"") - .append(StringEscapeUtils.escapeJson(fact)) + .append(escapeJson(fact)) .append("\",\"model_id\":\"") .append(memoryConfig.getEmbeddingModelId()) .append("\"}}}"); @@ -702,4 +704,167 @@ public static void 
createRemoteLongTermMemoryIndexWithPipeline( listener.onFailure(e); } } + + /** + * Creates an embedding model in remote AOSS collection + * + * @param connectorId The connector ID to use for remote storage + * @param embeddingModel The embedding model configuration + * @param remoteStoreCredential The remote store credentials (used if embedding model doesn't have its own) + * @param client The OpenSearch client + * @param listener The action listener that receives the created model ID + */ + public static void createRemoteEmbeddingModel( + String connectorId, + org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, + Map remoteStoreCredential, + Client client, + ActionListener listener + ) { + try { + // Build model registration request body + String requestBody = buildEmbeddingModelRegistrationBody(embeddingModel, remoteStoreCredential); + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INPUT_PARAM, requestBody); + parameters.put(NO_ESCAPE_PARAMS, INPUT_PARAM); + parameters.put(SKIP_VALIDATE_MISSING_PARAMETERS, "true"); + + // Execute the connector action with register_model action name + executeConnectorAction(connectorId, REGISTER_MODEL_ACTION, parameters, client, ActionListener.wrap(response -> { + // Parse model_id from response + String modelId = extractModelIdFromResponse(response); + log.info("Successfully created embedding model in remote store: {}", modelId); + listener.onResponse(modelId); + }, e -> { + log.error("Failed to create embedding model in remote store", e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error building embedding model registration request", e); + listener.onFailure(e); + } + } + + /** + * Builds the request body for embedding model registration in remote AOSS + */ + private static String buildEmbeddingModelRegistrationBody( + org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, + Map remoteStoreCredential + ) { + // Build connector configuration based on provider + // String connectorConfig = buildEmbeddingModelConnectorConfig(embeddingModel, remoteStoreCredential); + + String provider = embeddingModel.getProvider(); + String connectorConfig = buildBedrockEmbeddingConnectorConfig(provider, embeddingModel, remoteStoreCredential); + + // Build model name from provider and model ID (e.g., "bedrock/amazon.titan-embed-text-v2:0") + String sanitizedProvider = provider.replace("/", "-"); // AOSS doesn't allow / in model name. 
+ String sanitizedModelId = embeddingModel.getModelId().replace("/", "-"); + String modelName = String.format("%s__%s", sanitizedProvider, sanitizedModelId); + + return String + .format( + "{ \"function_name\": \"remote\", \"name\": \"%s\", \"description\": \"Auto-generated model\", \"connector\": %s }", + modelName, + connectorConfig + ); + } + + /** + * Builds Bedrock embedding connector configuration from template + */ + private static String buildBedrockEmbeddingConnectorConfig( + String provider, + org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, + Map remoteStoreCredential + ) { + try { + // Load template from resource file, sample provider "bedrock/text_embedding" + String template = loadConnectorTemplate(provider, embeddingModel.getModelId()); + + // Get parameters and credential from embedding model + Map parameters = embeddingModel.getParameters(); + Map credential = embeddingModel.getCredential(); + + // Use embedding model credentials if provided, otherwise use remote store credentials + if (credential == null || credential.isEmpty()) { + credential = remoteStoreCredential; + } + + // Validate that parameters are provided + if (parameters == null || parameters.isEmpty()) { + throw new IllegalArgumentException("Bedrock embedding model requires parameters block"); + } + + // Parse the template as JSON and inject parameters and credential + String connectorConfig = injectParametersAndCredential(template, parameters, credential); + + return connectorConfig; + } catch (IOException e) { + log.error("Failed to load connector template", e); + throw new IllegalArgumentException("Failed to load connector template"); + } + } + + /** + * Injects parameters and credential into the connector template + */ + private static String injectParametersAndCredential(String template, Map parameters, Map credential) + throws IOException { + // Parse template as JSON + XContentParser parser = XContentHelper + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, new BytesArray(template), XContentType.JSON); + Map connectorMap = parser.mapOrdered(); + + // Inject parameters + connectorMap.put("parameters", parameters); + + // Inject credential + connectorMap.put("credential", credential); + + // Convert back to JSON string + return StringUtils.toJson(connectorMap); + } + + /** + * Loads connector template from resource file + * Path format: model-connectors///.json + * + * @param provider The model provider (e.g., "bedrock", "openai", "cohere") + * @param modelId The model identifier (e.g., "amazon.titan-embed-text-v2") + * @return The connector template as a string + */ + private static String loadConnectorTemplate(String provider, String modelId) throws IOException { + // Normalize model ID for file name (replace : with -) + String normalizedModelId = modelId.replace(":", "-"); + String path = String.format("model-connectors/%s/%s.json", provider, normalizedModelId); + + try (java.io.InputStream is = RemoteStorageHelper.class.getClassLoader().getResourceAsStream(path)) { + if (is == null) { + throw new IOException("Connector template not found: " + path); + } + return new String(is.readAllBytes(), java.nio.charset.StandardCharsets.UTF_8); + } + } + + /** + * Extracts model ID from model registration response + */ + private static String extractModelIdFromResponse(ModelTensorOutput response) { + try { + Map dataAsMap = response.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + Object modelIdObj = dataAsMap.get("model_id"); + if 
(modelIdObj == null) { + throw new IllegalArgumentException("model_id not found in response"); + } + return modelIdObj.toString(); + } catch (Exception e) { + log.error("Failed to parse model_id from response", e); + throw new IllegalArgumentException("Failed to parse model_id from response", e); + } + } } From f0d2be1cd91b1759c7f1477ed286b52b7f9d75bb Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 22 Oct 2025 19:26:32 -0700 Subject: [PATCH 05/58] fix bulk load bug Signed-off-by: Yaliang Wu --- .../TransportCreateMemoryContainerAction.java | 83 +++++++++++-------- .../ml/helper/RemoteStorageHelper.java | 24 +++--- 2 files changed, 61 insertions(+), 46 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 8480a27c22..18ba79da08 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -7,8 +7,22 @@ import static org.opensearch.ml.common.CommonValue.ML_MEMORY_CONTAINER_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE; +import static org.opensearch.ml.helper.RemoteStorageHelper.BULK_LOAD_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INDEX_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INGEST_PIPELINE_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.DELETE_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.GET_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.REGISTER_MODEL_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.SEARCH_INDEX_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.UPDATE_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteStorageHelper.WRITE_DOC_ACTION; import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteResponse; @@ -21,6 +35,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; @@ -481,14 +496,14 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen String connectorName = "auto_" + remoteStore.getType().name().toLowerCase() + "_connector_" - + java.util.UUID.randomUUID().toString().substring(0, 8); + + UUID.randomUUID().toString().substring(0, 8); // Build connector actions based on remote store type - java.util.List actions = buildConnectorActions(remoteStore); + List actions = buildConnectorActions(remoteStore); // Get credential and parameters from remote store - java.util.Map credential = remoteStore.getCredential(); - java.util.Map parameters = remoteStore.getParameters(); + Map credential = remoteStore.getCredential(); + Map parameters = remoteStore.getParameters(); // Determine protocol based on parameters or credential String 
protocol = determineProtocol(parameters, credential); @@ -532,7 +547,7 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen /** * Determines the protocol based on parameters and credentials */ - private String determineProtocol(java.util.Map parameters, java.util.Map credential) { + private String determineProtocol(Map parameters, Map credential) { // Check if service_name is in parameters (indicates AWS SigV4) if (parameters != null && parameters.containsKey("service_name")) { return "aws_sigv4"; @@ -552,11 +567,11 @@ private String determineProtocol(java.util.Map parameters, java. /** * Builds connector actions based on remote store type */ - private java.util.List buildConnectorActions(RemoteStore remoteStore) { - java.util.List actions = new java.util.ArrayList<>(); + private List buildConnectorActions(RemoteStore remoteStore) { + List actions = new ArrayList<>(); String endpoint = remoteStore.getEndpoint(); - java.util.Map parameters = remoteStore.getParameters(); - java.util.Map credential = remoteStore.getCredential(); + Map parameters = remoteStore.getParameters(); + Map credential = remoteStore.getCredential(); // Determine if AWS SigV4 or basic auth boolean isAwsSigV4 = (parameters != null && parameters.containsKey("service_name")) @@ -564,7 +579,7 @@ private java.util.List build boolean isBasicAuth = credential != null && credential.containsKey("basic_auth_key"); // Common headers for JSON - java.util.Map jsonHeaders = new java.util.HashMap<>(); + Map jsonHeaders = new HashMap<>(); jsonHeaders.put("content-type", "application/json"); if (isAwsSigV4) { jsonHeaders.put("x-amz-content-sha256", "required"); @@ -573,13 +588,27 @@ private java.util.List build jsonHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); } + // Register model action - for creating embedding models in remote AOSS + actions + .add( + org.opensearch.ml.common.connector.ConnectorAction + .builder() + .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) + .name(REGISTER_MODEL_ACTION) + .method("POST") + .url(endpoint + "/_plugins/_ml/models/_register") + .headers(jsonHeaders) + .requestBody("${parameters.input}") + .build() + ); + // Create ingest pipeline action actions .add( org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("create_ingest_pipeline") + .name(CREATE_INGEST_PIPELINE_ACTION) .method("PUT") .url(endpoint + "/_ingest/pipeline/${parameters.pipeline_name}") .headers(jsonHeaders) @@ -593,7 +622,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("create_index") + .name(CREATE_INDEX_ACTION) .method("PUT") .url(endpoint + "/${parameters.index_name}") .headers(jsonHeaders) @@ -607,7 +636,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("write_doc") + .name(WRITE_DOC_ACTION) .method("POST") .url(endpoint + "/${parameters.index_name}/_doc") .headers(jsonHeaders) @@ -616,13 +645,13 @@ private java.util.List build ); // Bulk load action - java.util.Map bulkHeaders = new java.util.HashMap<>(); + Map bulkHeaders = new HashMap<>(); bulkHeaders.put("content-type", "application/x-ndjson"); if (isAwsSigV4) { bulkHeaders.put("x-amz-content-sha256", "required"); } if 
(isBasicAuth) { - bulkHeaders.put("Authorization", "Basic ${credential.auth_key}"); + bulkHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); } actions @@ -630,7 +659,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("bulk_load") + .name(BULK_LOAD_ACTION) .method("POST") .url(endpoint + "/_bulk") .headers(bulkHeaders) @@ -644,7 +673,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("search_index") + .name(SEARCH_INDEX_ACTION) .method("POST") .url(endpoint + "/${parameters.index_name}/_search") .headers(jsonHeaders) @@ -658,7 +687,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("get_doc") + .name(GET_DOC_ACTION) .method("GET") .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") .headers(jsonHeaders) @@ -671,7 +700,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("delete_doc") + .name(DELETE_DOC_ACTION) .method("DELETE") .url(endpoint + "/${parameters.index_name}/_doc/${parameters.doc_id}") .headers(jsonHeaders) @@ -685,7 +714,7 @@ private java.util.List build org.opensearch.ml.common.connector.ConnectorAction .builder() .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("update_doc") + .name(UPDATE_DOC_ACTION) .method("POST") .url(endpoint + "/${parameters.index_name}/_update/${parameters.doc_id}") .headers(jsonHeaders) @@ -693,20 +722,6 @@ private java.util.List build .build() ); - // Register model action - for creating embedding models in remote AOSS - actions - .add( - org.opensearch.ml.common.connector.ConnectorAction - .builder() - .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) - .name("register_model") - .method("POST") - .url(endpoint + "/_plugins/_ml/models/_register") - .headers(jsonHeaders) - .requestBody("${parameters.input}") - .build() - ); - return actions; } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java index b3e7a8ca00..cd07eaaa7f 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java @@ -68,18 +68,18 @@ @Log4j2 public class RemoteStorageHelper { - private static final String REGISTER_MODEL_ACTION = "register_model"; - private static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; - private static final String CREATE_INDEX_ACTION = "create_index"; - private static final String WRITE_DOC_ACTION = "write_doc"; - private static final String BULK_LOAD_ACTION = "bulk_load"; - private static final String SEARCH_INDEX_ACTION = "search_index"; - private static final String UPDATE_DOC_ACTION = "update_doc"; - private static final String GET_DOC_ACTION = "get_doc"; - private static final String DELETE_DOC_ACTION = "delete_doc"; - private static final String INDEX_NAME_PARAM = "index_name"; - private static final String DOC_ID_PARAM = "doc_id"; - private static final String INPUT_PARAM = "input"; + 
public static final String REGISTER_MODEL_ACTION = "register_model"; + public static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; + public static final String CREATE_INDEX_ACTION = "create_index"; + public static final String WRITE_DOC_ACTION = "write_doc"; + public static final String BULK_LOAD_ACTION = "bulk_load"; + public static final String SEARCH_INDEX_ACTION = "search_index"; + public static final String GET_DOC_ACTION = "get_doc"; + public static final String DELETE_DOC_ACTION = "delete_doc"; + public static final String UPDATE_DOC_ACTION = "update_doc"; + public static final String INDEX_NAME_PARAM = "index_name"; + public static final String DOC_ID_PARAM = "doc_id"; + public static final String INPUT_PARAM = "input"; /** * Creates a memory index in remote storage using a connector From ee60d161aa748d6075880a808f1dac689a8cdf2a Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 23 Oct 2025 00:40:03 -0700 Subject: [PATCH 06/58] fix connector header for remote store Signed-off-by: Yaliang Wu --- .../amazon.titan-embed-text-v2-0.json | 4 +-- .../algorithms/remote/ConnectorUtils.java | 18 +++++++++++ .../TransportCreateMemoryContainerAction.java | 32 +++---------------- .../ml/helper/RemoteStorageHelper.java | 31 ++++++++++++++++++ 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json b/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json index 09d076a79b..71a7112b19 100644 --- a/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json +++ b/common/src/main/resources/model-connectors/bedrock/text_embedding/amazon.titan-embed-text-v2-0.json @@ -2,15 +2,13 @@ "name": "Amazon Bedrock Connector: embedding", "description": "Connector to bedrock embedding model", "version": 1, - "protocol": "aws_sigv4", "actions": [ { "action_type": "predict", "method": "POST", "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke", "headers": { - "content-type": "application/json", - "x-amz-content-sha256": "required" + "content-type": "application/json" }, "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}, \"normalize\": ${parameters.normalize}, \"embeddingTypes\": ${parameters.embeddingTypes} }", "pre_process_function": "connector.pre_process.bedrock.embedding", diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 88c0ce7587..315c55b8d2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -12,6 +12,8 @@ import static org.opensearch.ml.common.connector.ConnectorAction.COHERE; import static org.opensearch.ml.common.connector.ConnectorAction.OPENAI; import static org.opensearch.ml.common.connector.ConnectorAction.SAGEMAKER; +import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; +import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; import static 
org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; @@ -77,6 +79,22 @@ public class ConnectorUtils { signer = AwsV4HttpSigner.create(); } + /** + * Determines the protocol based on parameters and credentials + */ + public static String determineProtocol(Map parameters, Map credential) { + boolean hasAwsRegion = parameters != null && parameters.containsKey("region"); + boolean hasAwsServiceName = parameters != null && parameters.containsKey("service_name"); + boolean hasRoleArn = credential != null && credential.containsKey("roleArn"); + boolean hasAwsCredential = credential != null && credential.containsKey("access_key") && credential.containsKey("secret_key"); + // Check if service_name is in parameters (indicates AWS SigV4) + if (hasAwsRegion && hasAwsServiceName && (hasRoleArn || hasAwsCredential)) { + return AWS_SIGV4; + } + // Default to http (for basic auth or other) + return HTTP; + } + public static RemoteInferenceInputDataSet processInput( String action, MLInput mlInput, diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 18ba79da08..1b05d1a1b7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_MEMORY_CONTAINER_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.determineProtocol; import static org.opensearch.ml.helper.RemoteStorageHelper.BULK_LOAD_ACTION; import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INDEX_ACTION; import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INGEST_PIPELINE_ACTION; @@ -544,26 +545,6 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen } } - /** - * Determines the protocol based on parameters and credentials - */ - private String determineProtocol(Map parameters, Map credential) { - // Check if service_name is in parameters (indicates AWS SigV4) - if (parameters != null && parameters.containsKey("service_name")) { - return "aws_sigv4"; - } - // Check if roleArn is in credential (indicates AWS SigV4) - if (credential != null && credential.containsKey("roleArn")) { - return "aws_sigv4"; - } - // Check if access_key and secret_key are in credential (indicates AWS SigV4) - if (credential != null && credential.containsKey("access_key") && credential.containsKey("secret_key")) { - return "aws_sigv4"; - } - // Default to http (for basic auth or other) - return "http"; - } - /** * Builds connector actions based on remote store type */ @@ -574,17 +555,15 @@ private List buildConnectorActions(RemoteStore remoteStore) { Map credential = remoteStore.getCredential(); // Determine if AWS SigV4 or basic auth - boolean isAwsSigV4 = (parameters != null && parameters.containsKey("service_name")) - || (credential != null && (credential.containsKey("roleArn") || credential.containsKey("access_key"))); - boolean isBasicAuth = credential != null && credential.containsKey("basic_auth_key"); + String protocol = determineProtocol(parameters, credential); // Common headers for JSON Map jsonHeaders = new HashMap<>(); 
jsonHeaders.put("content-type", "application/json"); + boolean isAwsSigV4 = "aws_sigv4".equals(protocol); if (isAwsSigV4) { jsonHeaders.put("x-amz-content-sha256", "required"); - } - if (isBasicAuth) { + } else { // TODO: add more auth options jsonHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); } @@ -649,8 +628,7 @@ private List buildConnectorActions(RemoteStore remoteStore) { bulkHeaders.put("content-type", "application/x-ndjson"); if (isAwsSigV4) { bulkHeaders.put("x-amz-content-sha256", "required"); - } - if (isBasicAuth) { + } else { bulkHeaders.put("Authorization", "Basic ${credential.basic_auth_key}"); } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java index cd07eaaa7f..bda8589f18 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.CommonValue.ML_LONG_TERM_MEMORY_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_SESSION_INDEX_MAPPING_PATH; import static org.opensearch.ml.common.CommonValue.ML_WORKING_MEMORY_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LONG_TERM_MEMORY_HISTORY_INDEX; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LONG_TERM_MEMORY_INDEX; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; @@ -23,11 +24,13 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_INDEX; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STRATEGY_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.WORKING_MEMORY_INDEX; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.ToolUtils.NO_ESCAPE_PARAMS; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.opensearch.action.bulk.BulkResponse; @@ -57,6 +60,7 @@ import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.transport.client.Client; @@ -80,6 +84,8 @@ public class RemoteStorageHelper { public static final String INDEX_NAME_PARAM = "index_name"; public static final String DOC_ID_PARAM = "doc_id"; public static final String INPUT_PARAM = "input"; + public static final String HEADERS_FIELD = "headers"; + public static final String ACTIONS_FIELD = "actions"; /** * Creates a memory index in remote storage using a connector @@ -826,6 +832,31 @@ private static String injectParametersAndCredential(String template, Map action = (Map) actionObj; + if (action.containsKey(HEADERS_FIELD)) { + Map headers = (Map) action.get(HEADERS_FIELD); + if (isAwsSigv4) { + headers.put("x-amz-content-sha256", "required"); + } + headers.putAll(headersMap); + } else { + action.put(HEADERS_FIELD, headersMap); + } + } + } + 
// Convert back to JSON string return StringUtils.toJson(connectorMap); } From 80ec214faf27727884e61f3d508d83ff3377cb3b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 23 Oct 2025 22:19:08 -0700 Subject: [PATCH 07/58] support ingest/search pipeline when create memory container Signed-off-by: Yaliang Wu --- .../memorycontainer/MemoryConfiguration.java | 71 ++++++++++++++----- .../MemoryContainerConstants.java | 2 + .../common/memorycontainer/RemoteStore.java | 32 ++++++++- .../TransportCreateMemoryContainerAction.java | 13 +++- .../memory/MemorySearchService.java | 5 +- .../memory/TransportSearchMemoriesAction.java | 4 ++ .../ml/helper/MemoryContainerHelper.java | 3 +- .../helper/MemoryContainerPipelineHelper.java | 34 +++++++-- .../ml/helper/RemoteStorageHelper.java | 55 ++++++++++++++ 9 files changed, 191 insertions(+), 28 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java index 41a89e2d26..eff5ec952f 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryConfiguration.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.EMBEDDING_MODEL_TYPE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INDEX_PREFIX_INVALID_CHARACTERS_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INDEX_SETTINGS_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INGEST_PIPELINE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.INVALID_EMBEDDING_MODEL_TYPE_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.LLM_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MAX_INFER_SIZE_DEFAULT_VALUE; @@ -24,7 +25,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_INDEX_PREFIX_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETERS_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.REMOTE_STORE_FIELD; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEARCH_PIPELINE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SPARSE_ENCODING_DIMENSION_NOT_ALLOWED_ERROR; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STRATEGIES_FIELD; @@ -88,6 +89,10 @@ public class MemoryConfiguration implements ToXContentObject, Writeable { private String tenantId; private RemoteStore remoteStore; + // Optional pre-existing pipeline configuration for local storage + private String ingestPipeline; + private String searchPipeline; + public MemoryConfiguration( String indexPrefix, FunctionName embeddingModelType, @@ -102,7 +107,9 @@ public MemoryConfiguration( boolean disableSession, boolean useSystemIndex, String tenantId, - RemoteStore remoteStore + RemoteStore remoteStore, + String ingestPipeline, + String searchPipeline ) { // Validate 
first validateInputs(embeddingModelType, embeddingModelId, dimension, maxInferSize); @@ -131,6 +138,8 @@ public MemoryConfiguration( this.useSystemIndex = useSystemIndex; this.tenantId = tenantId; this.remoteStore = remoteStore; + this.ingestPipeline = ingestPipeline; + this.searchPipeline = searchPipeline; } private String buildIndexPrefix(String indexPrefix, boolean useSystemIndex) { @@ -175,6 +184,8 @@ public MemoryConfiguration(StreamInput input) throws IOException { if (input.readBoolean()) { this.remoteStore = new RemoteStore(input); } + this.ingestPipeline = input.readOptionalString(); + this.searchPipeline = input.readOptionalString(); } @Override @@ -213,6 +224,8 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalString(ingestPipeline); + out.writeOptionalString(searchPipeline); } @Override @@ -266,6 +279,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (remoteStore != null) { builder.field(REMOTE_STORE_FIELD, remoteStore); } + if (ingestPipeline != null) { + builder.field(INGEST_PIPELINE_FIELD, ingestPipeline); + } + if (searchPipeline != null) { + builder.field(SEARCH_PIPELINE_FIELD, searchPipeline); + } builder.endObject(); return builder; } @@ -285,6 +304,8 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio boolean useSystemIndex = true; String tenantId = null; RemoteStore remoteStore = null; + String ingestPipeline = null; + String searchPipeline = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -339,6 +360,12 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio case REMOTE_STORE_FIELD: remoteStore = RemoteStore.parse(parser); break; + case INGEST_PIPELINE_FIELD: + ingestPipeline = parser.text(); + break; + case SEARCH_PIPELINE_FIELD: + searchPipeline = parser.text(); + break; default: parser.skipChildren(); break; @@ -362,6 +389,8 @@ public static MemoryConfiguration parse(XContentParser parser) throws IOExceptio .useSystemIndex(useSystemIndex) .tenantId(tenantId) .remoteStore(remoteStore) + .ingestPipeline(ingestPipeline) + .searchPipeline(searchPipeline) .build(); } @@ -437,9 +466,9 @@ private static void validateEmbeddingConfiguration(FunctionName embeddingModelTy if (embeddingModelId != null && embeddingModelType == null) { throw new IllegalArgumentException(SEMANTIC_STORAGE_EMBEDDING_MODEL_TYPE_REQUIRED_ERROR); } - if (embeddingModelType != null && embeddingModelId == null) { - throw new IllegalArgumentException(SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR); - } + // if (embeddingModelType != null && embeddingModelId == null) { + // throw new IllegalArgumentException(SEMANTIC_STORAGE_EMBEDDING_MODEL_ID_REQUIRED_ERROR); + // } // If embedding model type is provided, validate it if (embeddingModelType != null) { @@ -499,20 +528,24 @@ public static void validateStrategiesRequireModels(MemoryConfiguration config) { hasEmbedding = config.getRemoteStore().getEmbeddingModelId() != null && config.getRemoteStore().getEmbeddingModelId() != null; } - if (!hasLlm || !hasEmbedding) { - String missing = !hasLlm && !hasEmbedding ? "LLM model and embedding model" - : !hasLlm ? 
"LLM model (llm_id)" - : "embedding model (embedding_model_id, embedding_model_type, dimension)"; - - throw new IllegalArgumentException( - String - .format( - "Strategies require both an LLM model and embedding model to be configured. Missing: %s. " - + "Strategies use LLM for fact extraction and embedding model for semantic search.", - missing - ) - ); - } + if (!hasLlm) { + throw new IllegalArgumentException("Strategies require an LLM model to be configured."); + } + + // if (!hasLlm || !hasEmbedding) { + // String missing = !hasLlm && !hasEmbedding ? "LLM model and embedding model" + // : !hasLlm ? "LLM model (llm_id)" + // : "embedding model (embedding_model_id, embedding_model_type, dimension)"; + // + // throw new IllegalArgumentException( + // String + // .format( + // "Strategies require both an LLM model and embedding model to be configured. Missing: %s. " + // + "Strategies use LLM for fact extraction and embedding model for semantic search.", + // missing + // ) + // ); + // } } /** diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java index 4253150bba..a4421a9daf 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java @@ -40,6 +40,8 @@ public class MemoryContainerConstants { public static final String ID_FIELD = "id"; public static final String ENABLED_FIELD = "enabled"; public static final String REMOTE_STORE_FIELD = "remote_store"; + public static final String INGEST_PIPELINE_FIELD = "ingest_pipeline"; + public static final String SEARCH_PIPELINE_FIELD = "search_pipeline"; // Default values public static final int MAX_INFER_SIZE_DEFAULT_VALUE = 5; diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java index 264af6c3b9..41d1262ace 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java @@ -38,6 +38,8 @@ public class RemoteStore implements ToXContentObject, Writeable { public static final String PARAMETERS_FIELD = "parameters"; public static final String CREDENTIAL_FIELD = "credential"; public static final String EMBEDDING_MODEL_FIELD = "embedding_model"; + public static final String INGEST_PIPELINE_FIELD = "ingest_pipeline"; + public static final String SEARCH_PIPELINE_FIELD = "search_pipeline"; private RemoteStoreType type; private String connectorId; @@ -53,6 +55,10 @@ public class RemoteStore implements ToXContentObject, Writeable { // Auto embedding model creation private RemoteEmbeddingModel embeddingModel; + // Optional pre-existing pipeline configuration + private String ingestPipeline; + private String searchPipeline; + @Builder public RemoteStore( RemoteStoreType type, @@ -63,7 +69,9 @@ public RemoteStore( String endpoint, Map parameters, Map credential, - RemoteEmbeddingModel embeddingModel + RemoteEmbeddingModel embeddingModel, + String ingestPipeline, + String searchPipeline ) { if (type == null) { throw new IllegalArgumentException("Invalid remote store type"); @@ -77,6 +85,8 @@ public RemoteStore( this.parameters = parameters != null ? new java.util.HashMap<>(parameters) : new java.util.HashMap<>(); this.credential = credential != null ? 
new java.util.HashMap<>(credential) : new java.util.HashMap<>(); this.embeddingModel = embeddingModel; + this.ingestPipeline = ingestPipeline; + this.searchPipeline = searchPipeline; } public RemoteStore(StreamInput input) throws IOException { @@ -101,6 +111,8 @@ public RemoteStore(StreamInput input) throws IOException { if (input.readBoolean()) { this.embeddingModel = new RemoteEmbeddingModel(input); } + this.ingestPipeline = input.readOptionalString(); + this.searchPipeline = input.readOptionalString(); } @Override @@ -134,6 +146,8 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalString(ingestPipeline); + out.writeOptionalString(searchPipeline); } @Override @@ -163,6 +177,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (embeddingModel != null) { builder.field(EMBEDDING_MODEL_FIELD, embeddingModel); } + if (ingestPipeline != null) { + builder.field(INGEST_PIPELINE_FIELD, ingestPipeline); + } + if (searchPipeline != null) { + builder.field(SEARCH_PIPELINE_FIELD, searchPipeline); + } // Don't serialize credentials for security - they are stored in the connector builder.endObject(); return builder; @@ -178,6 +198,8 @@ public static RemoteStore parse(XContentParser parser) throws IOException { Map parameters = new java.util.HashMap<>(); Map credential = new java.util.HashMap<>(); RemoteEmbeddingModel embeddingModel = null; + String ingestPipeline = null; + String searchPipeline = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -212,6 +234,12 @@ public static RemoteStore parse(XContentParser parser) throws IOException { case EMBEDDING_MODEL_FIELD: embeddingModel = RemoteEmbeddingModel.parse(parser); break; + case INGEST_PIPELINE_FIELD: + ingestPipeline = parser.text(); + break; + case SEARCH_PIPELINE_FIELD: + searchPipeline = parser.text(); + break; default: parser.skipChildren(); break; @@ -229,6 +257,8 @@ public static RemoteStore parse(XContentParser parser) throws IOException { .parameters(parameters) .credential(credential) .embeddingModel(embeddingModel) + .ingestPipeline(ingestPipeline) + .searchPipeline(searchPipeline) .build(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 1b05d1a1b7..fed445e9d0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -358,7 +358,9 @@ private void validateConfiguration(MemoryConfiguration config, ActionListener buildConnectorActions(RemoteStore remoteStore) { .actionType(org.opensearch.ml.common.connector.ConnectorAction.ActionType.EXECUTE) .name(SEARCH_INDEX_ACTION) .method("POST") - .url(endpoint + "/${parameters.index_name}/_search") + .url(endpoint + "/${parameters.index_name}/_search${parameters.search_pipeline:-}") .headers(jsonHeaders) .requestBody("${parameters.input}") .build() diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java index 8556629dee..2300ffbec9 100644 --- 
a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java @@ -117,7 +117,10 @@ private void searchFactsSequentially( .searchSourceBuilder(searchSourceBuilder) .tenantId(tenantId) .build(); - + // TODO: add search pipeline support in SearchDataObjectRequest + // if (memoryConfig.getSearchPipeline() != null) { + // searchRequest.pipeline(memoryConfig.getSearchPipeline()); + // } memoryContainerHelper.searchData(memoryConfig, searchRequest, searchResponseActionListener); } else { String query = MemorySearchQueryBuilder diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java index ab48d287f9..dae30eda32 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java @@ -128,6 +128,10 @@ private void searchMemories( .searchSourceBuilder(input.getSearchSourceBuilder()) .tenantId(tenantId) .build(); + // TODO: add search pipeline support in SearchDataObjectRequest + // if (memoryConfig.getSearchPipeline() != null) { + // searchDataObjecRequest.pipeline(memoryConfig.getSearchPipeline()); + // } // Execute search ActionListener searchResponseActionListener = ActionListener.wrap(response -> { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index d16ba9f6f0..e935ae23ea 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -305,7 +305,8 @@ public void searchDataFromRemoteStorage( ) { try { String connectorId = configuration.getRemoteStore().getConnectorId(); - RemoteStorageHelper.searchDocuments(connectorId, indexName, query, client, ActionListener.wrap(response -> { + String searchPipeline = configuration.getRemoteStore().getSearchPipeline(); + RemoteStorageHelper.searchDocuments(connectorId, indexName, query, searchPipeline, client, ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure)); } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java index fd6682c477..44e7959f71 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java @@ -38,7 +38,8 @@ public final class MemoryContainerPipelineHelper { /** * Creates an ingest pipeline and long-term memory index. *
<p>
- * If embedding is configured, creates a text embedding pipeline first, + * If a pre-existing ingest pipeline is configured, uses that pipeline directly. + * Otherwise, if embedding is configured, creates a text embedding pipeline first, * then creates the long-term index with the pipeline attached. * If no embedding is configured, creates the index without a pipeline. * @@ -56,7 +57,13 @@ public static void createLongTermMemoryIngestPipeline( ActionListener listener ) { try { - if (config.getEmbeddingModelType() != null) { + // Check if user provided a pre-existing ingest pipeline at configuration level + if (config.getIngestPipeline() != null && !config.getIngestPipeline().isEmpty()) { + log.info("Using pre-existing ingest pipeline from configuration: {}", config.getIngestPipeline()); + // Use the user-provided pipeline directly + indicesHandler.createLongTermMemoryIndex(config.getIngestPipeline(), indexName, config, listener); + } else if (config.getEmbeddingModelType() != null) { + // Auto-create pipeline if embedding model is configured String pipelineName = indexName + "-embedding"; createTextEmbeddingPipeline(pipelineName, config, client, ActionListener.wrap(success -> { @@ -232,7 +239,8 @@ public static void createHistoryIndexIfEnabled( /** * Creates an ingest pipeline in remote storage and long-term memory index. *
<p>
- * If embedding is configured, creates a text embedding pipeline in the remote cluster first, + * If a pre-existing ingest pipeline is configured in remote_store, uses that pipeline directly. + * Otherwise, if embedding is configured, creates a text embedding pipeline in the remote cluster first, * then creates the long-term index with the pipeline attached. * If no embedding is configured, creates the index without a pipeline. * @@ -252,7 +260,25 @@ public static void createRemoteLongTermMemoryIngestPipeline( ActionListener listener ) { try { - if (config.getRemoteStore().getEmbeddingModelType() != null) { + RemoteStore remoteStore = config.getRemoteStore(); + + // Check if user provided a pre-existing ingest pipeline in remote_store + if (remoteStore.getIngestPipeline() != null && !remoteStore.getIngestPipeline().isEmpty()) { + log.info("Using pre-existing ingest pipeline from remote_store: {}", remoteStore.getIngestPipeline()); + // Use the user-provided pipeline directly + org.opensearch.ml.helper.RemoteStorageHelper + .createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( + connectorId, + indexName, + remoteStore.getIngestPipeline(), + remoteStore.getSearchPipeline(), + config, + indicesHandler, + client, + listener + ); + } else if (remoteStore.getEmbeddingModelType() != null) { + // Auto-create pipeline if embedding model is configured String pipelineName = indexName + "-embedding"; createRemoteTextEmbeddingPipeline(connectorId, pipelineName, config, client, ActionListener.wrap(success -> { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java index bda8589f18..a830d4a56a 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java @@ -21,6 +21,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.NAMESPACE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.NAMESPACE_SIZE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.OWNER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEARCH_PIPELINE_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SESSION_INDEX; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.STRATEGY_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.WORKING_MEMORY_INDEX; @@ -399,6 +400,7 @@ public static void searchDocuments( String connectorId, String indexName, String query, + String searchPipeline, Client client, ActionListener listener ) { @@ -406,6 +408,9 @@ public static void searchDocuments( // Prepare parameters for connector execution Map parameters = new HashMap<>(); parameters.put(INDEX_NAME_PARAM, indexName); + if (searchPipeline != null) { + parameters.put(SEARCH_PIPELINE_FIELD, "?search_pipeline=" + searchPipeline); + } parameters.put(INPUT_PARAM, query); // Execute the connector action with search_index action name @@ -711,6 +716,56 @@ public static void createRemoteLongTermMemoryIndexWithPipeline( } } + public static void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( + String connectorId, + String indexName, + String ingestPipelineName, + String searchPipelineName, + MemoryConfiguration memoryConfig, + MLIndicesHandler mlIndicesHandler, + Client client, + ActionListener listener + ) { + 
try { + String indexMapping = buildLongTermMemoryMapping(memoryConfig, mlIndicesHandler); + Map indexSettings = buildLongTermMemorySettings(memoryConfig); + + // Parse the mapping string to a Map + Map mappingMap = parseMappingToMap(indexMapping); + + // Build the request body for creating the index with pipeline + Map requestBody = new HashMap<>(); + requestBody.put("mappings", mappingMap); + + // Add settings with default pipeline (settings already have "index." prefix) + Map settings = new HashMap<>(indexSettings); + settings.put("index.default_pipeline", ingestPipelineName); + if (searchPipelineName != null) { + settings.put("index.search.default_pipeline", searchPipelineName); + } + requestBody.put("settings", settings); + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJson(requestBody)); + parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); + + // Execute the connector action + executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + log.info("Successfully created remote long-term memory index with pipeline: {}", indexName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote long-term memory index with pipeline: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote long-term memory index creation with pipeline for: {}", indexName, e); + listener.onFailure(e); + } + } + /** * Creates an embedding model in remote AOSS collection * From c369aecc886e8d6814a29f60507584aa73ac766b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 24 Oct 2025 15:47:53 -0700 Subject: [PATCH 08/58] Revert "agentic memory integration with agent framework (#4331)" This reverts commit fb90b1b0b97250f14e417a9d0aeff5a075cdbd99. 
--- .../org/opensearch/ml/common/MLAgentType.java | 2 +- .../opensearch/ml/common/MLMemoryType.java | 24 - .../opensearch/ml/common/agent/MLAgent.java | 16 +- .../transport/agent/MLAgentUpdateInput.java | 5 +- .../ml/common/MLAgentTypeTests.java | 4 +- .../agent/MLAgentUpdateInputTest.java | 2 +- .../agent/AgenticMemoryAdapter.java | 775 ----------------- .../agent/ChatHistoryTemplateEngine.java | 55 -- .../algorithms/agent/ChatMemoryAdapter.java | 124 --- .../engine/algorithms/agent/ChatMessage.java | 37 - .../algorithms/agent/MLAgentExecutor.java | 668 +++++--------- .../algorithms/agent/MLChatAgentRunner.java | 814 ++---------------- .../SimpleChatHistoryTemplateEngine.java | 81 -- .../ml/engine/memory/ChatMemoryAdapter.java | 0 .../ml/engine/memory/ChatMessage.java | 32 - .../agent/AgenticMemoryAdapterTest.java | 167 ---- .../agent/ChatMemoryAdapterTest.java | 115 --- .../agent/MLChatAgentRunnerTest.java | 132 --- .../ml/helper/MemoryContainerHelper.java | 5 +- 19 files changed, 318 insertions(+), 2740 deletions(-) delete mode 100644 common/src/main/java/org/opensearch/ml/common/MLMemoryType.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java delete mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java delete mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java delete mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java index 04a4b72014..2dd2614634 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -20,7 +20,7 @@ public static MLAgentType from(String value) { try { return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT)); } catch (Exception e) { - throw new IllegalArgumentException(value + " is not a valid Agent Type"); + throw new IllegalArgumentException("Wrong Agent type"); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java deleted file mode 100644 index 31939ce1ca..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common; - -import java.util.Locale; - -public enum MLMemoryType { - CONVERSATION_INDEX, - AGENTIC_MEMORY; - - public static MLMemoryType from(String value) { - if (value != null) { - try { - return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (Exception e) { - throw new IllegalArgumentException("Wrong Memory type"); - } - } - return null; - } -} diff --git 
a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index ec73d73856..b66a23f11e 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -112,7 +113,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - MLAgentType.from(type); + validateMLAgentType(type); if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { throw new IllegalArgumentException("We need model information for the conversational agent type"); } @@ -129,6 +130,19 @@ private void validate() { } } + private void validateMLAgentType(String agentType) { + if (type == null) { + throw new IllegalArgumentException("Agent type can't be null"); + } else { + try { + MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching + } catch (IllegalArgumentException e) { + // The typeStr does not match any MLAgentType, so throw a new exception with a clearer message. + throw new IllegalArgumentException(agentType + " is not a valid Agent Type"); + } + } + } + public MLAgent(StreamInput input) throws IOException { Version streamInputVersion = input.getVersion(); name = input.readString(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index e85b3f4bdc..9a0d6002fd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -26,7 +26,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -384,7 +383,9 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - MLMemoryType.from(memoryType); + if (memoryType != null && !memoryType.equals("conversation_index")) { + throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType)); + } if (tools != null) { Set toolNames = new HashSet<>(); for (MLToolSpec toolSpec : tools) { diff --git a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java index 05f37c4992..ee15ca95fd 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java @@ -44,14 +44,14 @@ public void testFromWithMixedCase() { public void testFromWithInvalidType() { // This should throw an IllegalArgumentException exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage(" is not a valid Agent Type"); + exceptionRule.expectMessage("Wrong Agent type"); MLAgentType.from("INVALID_TYPE"); } @Test public void 
testFromWithEmptyString() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage(" is not a valid Agent Type"); + exceptionRule.expectMessage("Wrong Agent type"); // This should also throw an IllegalArgumentException MLAgentType.from(""); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index 084f95d137..72eb035279 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build(); }); - assertEquals("Wrong Memory type", e.getMessage()); + assertEquals("Invalid memory type: invalid_type", e.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java deleted file mode 100644 index 6bf685bd7f..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapter.java +++ /dev/null @@ -1,775 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; - -import org.opensearch.action.search.SearchResponse; -import org.opensearch.core.action.ActionListener; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.ml.common.memorycontainer.MemoryType; -import org.opensearch.ml.common.memorycontainer.PayloadType; -import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetAction; -import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerGetRequest; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput; -import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest; -import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; -import org.opensearch.search.SearchHit; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortOrder; -import org.opensearch.transport.client.Client; - -import lombok.extern.log4j.Log4j2; - -/** - * Adapter for Agentic Memory system to work with MLChatAgentRunner. - * - *

<p>This adapter provides a bridge between the ML Chat Agent system and the Agentic Memory
- * infrastructure, enabling intelligent conversation management and context retention.</p>
- *
- * <p>Memory Types Handled:</p>
- * <ul>
- *   <li>WORKING memory: Recent conversation history and active interactions</li>
- *   <li>LONG_TERM memory: Extracted facts, summaries, and semantic insights</li>
- * </ul>
- *
- * <p>Key Features:</p>
- * <ul>
- *   <li>Dual memory querying for comprehensive context retrieval</li>
- *   <li>Dynamic inference configuration based on memory container LLM settings</li>
- *   <li>Structured trace data storage for tool invocation tracking</li>
- *   <li>Robust error handling with fallback mechanisms</li>
- * </ul>
- *
- * <p>Usage Example:</p>
- * <pre>{@code
- * AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(
- *     client, "memory-container-id", "session-123", "user-456"
- * );
- * 
- * // Retrieve conversation messages
- * adapter.getMessages(ActionListener.wrap(
- *     messages -> processMessages(messages),
- *     error -> handleError(error)
- * ));
- * 
- * // Save trace data
- * adapter.saveTraceData("search_tool", "query", "results", 
- *     "parent-id", 1, "search", listener);
- * }
- * - * @see ChatMemoryAdapter - * @see MLChatAgentRunner - */ -@Log4j2 -public class AgenticMemoryAdapter implements ChatMemoryAdapter { - private final Client client; - private final String memoryContainerId; - private final String sessionId; - private final String ownerId; - - /** - * Creates a new AgenticMemoryAdapter instance. - * - * @param client OpenSearch client for executing memory operations - * @param memoryContainerId Unique identifier for the memory container - * @param sessionId Session identifier for conversation context - * @param ownerId Owner/user identifier for access control - * @throws IllegalArgumentException if any required parameter is null - */ - public AgenticMemoryAdapter(Client client, String memoryContainerId, String sessionId, String ownerId) { - if (client == null) { - throw new IllegalArgumentException("Client cannot be null"); - } - if (memoryContainerId == null || memoryContainerId.trim().isEmpty()) { - throw new IllegalArgumentException("Memory container ID cannot be null or empty"); - } - if (sessionId == null || sessionId.trim().isEmpty()) { - throw new IllegalArgumentException("Session ID cannot be null or empty"); - } - if (ownerId == null || ownerId.trim().isEmpty()) { - throw new IllegalArgumentException("Owner ID cannot be null or empty"); - } - - this.client = client; - this.memoryContainerId = memoryContainerId; - this.sessionId = sessionId; - this.ownerId = ownerId; - } - - @Override - public void getMessages(ActionListener> listener) { - // Query both WORKING memory (recent conversations) and LONG_TERM memory - // (extracted facts) - // This provides both conversation history and learned context - - List allChatMessages = new ArrayList<>(); - AtomicInteger pendingRequests = new AtomicInteger(2); - - // 1. Get recent conversation history from WORKING memory - SearchSourceBuilder workingSearchBuilder = new SearchSourceBuilder() - .query( - QueryBuilders - .boolQuery() - .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) - .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) - ) - .sort("created_time", SortOrder.DESC) - .size(50); // Limit recent conversation history - - MLSearchMemoriesRequest workingRequest = MLSearchMemoriesRequest - .builder() - .mlSearchMemoriesInput( - MLSearchMemoriesInput - .builder() - .memoryContainerId(memoryContainerId) - .memoryType(MemoryType.WORKING) - .searchSourceBuilder(workingSearchBuilder) - .build() - ) - .build(); - - client.execute(MLSearchMemoriesAction.INSTANCE, workingRequest, ActionListener.wrap(workingResponse -> { - synchronized (allChatMessages) { - allChatMessages.addAll(parseAgenticMemoryResponse(workingResponse)); - if (pendingRequests.decrementAndGet() == 0) { - // Sort all chat messages by timestamp and return - allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); - listener.onResponse(allChatMessages); - } - } - }, listener::onFailure)); - - // 2. 
Get relevant context from LONG_TERM memory (extracted facts) - SearchSourceBuilder longTermSearchBuilder = new SearchSourceBuilder() - .query( - QueryBuilders - .boolQuery() - .must(QueryBuilders.termQuery("namespace.session_id", sessionId)) - .must(QueryBuilders.termQuery("namespace.user_id", ownerId)) - .should(QueryBuilders.termQuery("strategy_type", "SUMMARY")) - .should(QueryBuilders.termQuery("strategy_type", "SEMANTIC")) - ) - .sort("created_time", SortOrder.DESC) - .size(10); // Limit context facts - - MLSearchMemoriesRequest longTermRequest = MLSearchMemoriesRequest - .builder() - .mlSearchMemoriesInput( - MLSearchMemoriesInput - .builder() - .memoryContainerId(memoryContainerId) - .memoryType(MemoryType.LONG_TERM) - .searchSourceBuilder(longTermSearchBuilder) - .build() - ) - .build(); - - client.execute(MLSearchMemoriesAction.INSTANCE, longTermRequest, ActionListener.wrap(longTermResponse -> { - synchronized (allChatMessages) { - allChatMessages.addAll(parseAgenticMemoryResponse(longTermResponse)); - if (pendingRequests.decrementAndGet() == 0) { - // Sort all chat messages by timestamp and return - allChatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); - listener.onResponse(allChatMessages); - } - } - }, e -> { - // If long-term memory fails, still return working memory results - log.warn("Failed to retrieve long-term memory, continuing with working memory only", e); - synchronized (allChatMessages) { - if (pendingRequests.decrementAndGet() == 0) { - listener.onResponse(allChatMessages); - } - } - })); - } - - @Override - public String getConversationId() { - return sessionId; - } - - @Override - public String getMemoryContainerId() { - return memoryContainerId; - } - - @Override - public void saveInteraction( - String question, - String assistantResponse, - String parentId, - Integer traceNum, - String action, - ActionListener listener - ) { - if (listener == null) { - throw new IllegalArgumentException("ActionListener cannot be null"); - } - final String finalQuestion = question != null ? question : ""; - final String finalAssistantResponse = assistantResponse != null ? 
assistantResponse : ""; - - log - .info( - "AgenticMemoryAdapter.saveInteraction: Called with parentId: {}, action: {}, hasResponse: {}", - parentId, - action, - !finalAssistantResponse.isEmpty() - ); - - // If parentId is provided and we have a response, update the existing - // interaction - if (parentId != null && !finalAssistantResponse.isEmpty()) { - log.info("AgenticMemoryAdapter.saveInteraction: Updating existing interaction {} with final response", parentId); - - // Update the existing interaction with the complete conversation - Map updateFields = new HashMap<>(); - updateFields.put("response", finalAssistantResponse); - updateFields.put("input", finalQuestion); - - updateInteraction(parentId, updateFields, ActionListener.wrap(res -> { - log.info("AgenticMemoryAdapter.saveInteraction: Successfully updated interaction {}", parentId); - listener.onResponse(parentId); // Return the same interaction ID - }, ex -> { - log - .error( - "AgenticMemoryAdapter.saveInteraction: Failed to update interaction {}, falling back to create new", - parentId, - ex - ); - // Fallback to creating new interaction if update fails - createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); - })); - } else { - // Create new interaction (root interaction or when no parentId) - log.info("AgenticMemoryAdapter.saveInteraction: Creating new interaction - parentId: {}, action: {}", parentId, action); - createNewInteraction(finalQuestion, finalAssistantResponse, parentId, traceNum, action, listener); - } - } - - private void createNewInteraction( - String question, - String assistantResponse, - String parentId, - Integer traceNum, - String action, - ActionListener listener - ) { - List messages = Arrays - .asList( - MessageInput.builder().role("user").content(createTextContent(question)).build(), - MessageInput.builder().role("assistant").content(createTextContent(assistantResponse)).build() - ); - - // Create namespace map with proper String types - Map namespaceMap = new java.util.HashMap<>(); - namespaceMap.put("session_id", sessionId); - namespaceMap.put("user_id", ownerId); - - Map metadataMap = new java.util.HashMap<>(); - if (traceNum != null) { - metadataMap.put("trace_num", traceNum.toString()); - } - if (action != null) { - metadataMap.put("action", action); - } - if (parentId != null) { - metadataMap.put("parent_id", parentId); - } - - // Check if memory container has LLM ID configured to determine infer value - hasLlmIdConfigured(ActionListener.wrap(hasLlmId -> { - MLAddMemoriesInput input = MLAddMemoriesInput - .builder() - .memoryContainerId(memoryContainerId) - .messages(messages) - .namespace(namespaceMap) - .metadata(metadataMap) - .ownerId(ownerId) - .infer(hasLlmId) // Use dynamic infer based on LLM ID presence - .build(); - - MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); - - client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { - log - .info( - "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: {}", - addResponse.getWorkingMemoryId(), - addResponse.getSessionId(), - action, - hasLlmId - ); - listener.onResponse(addResponse.getWorkingMemoryId()); - }, listener::onFailure)); - }, ex -> { - log.warn("Failed to check LLM ID configuration for interaction, proceeding with infer=false", ex); - // Fallback to infer=false if we can't determine LLM ID status - MLAddMemoriesInput input = MLAddMemoriesInput - .builder() - 
.memoryContainerId(memoryContainerId) - .messages(messages) - .namespace(namespaceMap) - .metadata(metadataMap) - .ownerId(ownerId) - .infer(false) - .build(); - - MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); - - client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { - log - .info( - "AgenticMemoryAdapter.createNewInteraction: Created interaction with ID: {}, sessionId: {}, action: {}, infer: false (fallback)", - addResponse.getWorkingMemoryId(), - addResponse.getSessionId(), - action - ); - listener.onResponse(addResponse.getWorkingMemoryId()); - }, listener::onFailure)); - })); - } - - /** - * Save trace data as structured tool invocation information in working memory. - * - *

<p>This method stores detailed information about tool executions, including inputs,
- * outputs, and contextual metadata. The data is stored with appropriate tags and
- * namespace information for later retrieval and analysis.</p>
- *
- * <p>Important: This method always uses {@code infer=false} to prevent
- * LLM-based long-term memory extraction from tool traces. Tool execution data is already
- * structured and queryable, and extracting facts from intermediate steps would create
- * fragmented, duplicate long-term memories. Semantic extraction happens only on final
- * conversation interactions via {@link #saveInteraction}.</p>
- * - * @param toolName Name of the tool that was executed (required, non-empty) - * @param toolInput Input parameters passed to the tool (nullable, defaults to empty string) - * @param toolOutput Output/response from the tool execution (nullable, defaults to empty string) - * @param parentMemoryId Parent memory ID to associate this trace with (nullable) - * @param traceNum Trace sequence number for ordering (nullable) - * @param action Action/origin identifier for categorization (nullable) - * @param listener ActionListener to handle the response with the created memory ID - * @throws IllegalArgumentException if toolName is null/empty or listener is null - * @see #saveInteraction for conversational data that triggers long-term memory extraction - */ - @Override - public void saveTraceData( - String toolName, - String toolInput, - String toolOutput, - String parentMemoryId, - Integer traceNum, - String action, - ActionListener listener - ) { - if (toolName == null || toolName.trim().isEmpty()) { - throw new IllegalArgumentException("Tool name cannot be null or empty"); - } - if (listener == null) { - throw new IllegalArgumentException("ActionListener cannot be null"); - } - final String finalToolName = toolName; - - // Create tool invocation structured data - Map toolInvocation = new HashMap<>(); - toolInvocation.put("tool_name", finalToolName); - toolInvocation.put("tool_input", toolInput != null ? toolInput : ""); - toolInvocation.put("tool_output", toolOutput != null ? toolOutput : ""); - - Map structuredData = new HashMap<>(); - structuredData.put("tool_invocations", List.of(toolInvocation)); - - // Create namespace map - Map namespaceMap = new HashMap<>(); - namespaceMap.put("session_id", sessionId); - namespaceMap.put("user_id", ownerId); - - // Create metadata map - Map metadataMap = new HashMap<>(); - metadataMap.put("status", "checkpoint"); - if (traceNum != null) { - metadataMap.put("trace_num", traceNum.toString()); - } - if (action != null) { - metadataMap.put("action", action); - } - if (parentMemoryId != null) { - metadataMap.put("parent_memory_id", parentMemoryId); - } - - // Create tags map with trace-specific information - Map tagsMap = new HashMap<>(); - tagsMap.put("data_type", "trace"); - - if (action != null) { - tagsMap.put("topic", action); - } - - /* - * IMPORTANT: Tool trace data uses infer=false to prevent long-term memory extraction - * - * Rationale: - * 1. Tool traces are intermediate execution steps, not final user-facing content - * 2. Running LLM inference on tool traces would create fragmented, low-quality long-term memories - * 3. Multiple tool executions in a single conversation would generate redundant/duplicate facts - * 4. Tool trace data is already structured (tool_name, tool_input, tool_output) and queryable - * 5. Final conversation interactions (saveInteraction) will trigger proper semantic extraction - * - * Example problem if infer=true: - * User: "What's the weather in Seattle?" 
- * - Tool trace saved → LLM extracts: "User queried Seattle" (incomplete context) - * - Final response saved → LLM extracts: "User asked about Seattle weather" (complete context) - * Result: Duplicate/conflicting long-term memories - * - * By setting infer=false for tool traces: - * - Tool execution data remains queryable via structured data search - * - Long-term memory extraction happens only on final, contextually complete interactions - * - Cleaner, more accurate long-term memory without duplication - * - Reduced LLM inference costs and processing overhead - */ - executeTraceDataSave(structuredData, namespaceMap, metadataMap, tagsMap, false, finalToolName, action, listener); - } - - /** - * Execute the actual trace data save operation. - * - *

<p>Note: The infer parameter is kept for potential future use cases where selective
- * inference on tool traces might be needed, but currently always receives false to
- * prevent duplicate long-term memory extraction.</p>
- * - * @param structuredData The structured data containing tool invocation information - * @param namespaceMap The namespace mapping for the memory - * @param metadataMap The metadata for the memory entry - * @param tagsMap The tags for the memory entry - * @param infer Whether to enable inference processing (currently always false for tool traces) - * @param toolName The name of the tool (for logging) - * @param action The action identifier (for logging) - * @param listener ActionListener to handle the response - */ - private void executeTraceDataSave( - Map structuredData, - Map namespaceMap, - Map metadataMap, - Map tagsMap, - boolean infer, - String toolName, - String action, - ActionListener listener - ) { - try { - MLAddMemoriesInput input = MLAddMemoriesInput - .builder() - .memoryContainerId(memoryContainerId) - .structuredData(structuredData) - .namespace(namespaceMap) - .metadata(metadataMap) - .tags(tagsMap) - .ownerId(ownerId) - .payloadType(PayloadType.DATA) - .infer(infer) - .build(); - - MLAddMemoriesRequest request = new MLAddMemoriesRequest(input); - - client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(addResponse -> { - log - .info( - "AgenticMemoryAdapter.saveTraceData: Successfully saved trace data with ID: {}, toolName: {}, action: {}, infer: {}", - addResponse.getWorkingMemoryId(), - toolName, - action, - infer - ); - listener.onResponse(addResponse.getWorkingMemoryId()); - }, ex -> { - log - .error( - "AgenticMemoryAdapter.saveTraceData: Failed to save trace data for tool: {}, action: {}, infer: {}. Error: {}", - toolName, - action, - infer, - ex.getMessage(), - ex - ); - listener.onFailure(ex); - })); - } catch (Exception e) { - log - .error( - "AgenticMemoryAdapter.saveTraceData: Exception while building trace data save request for tool: {}, action: {}", - toolName, - action, - e - ); - listener.onFailure(e); - } - } - - /** - * Check if the memory container has an LLM ID configured for inference - * @param callback ActionListener to handle the result (true if LLM ID exists, false otherwise) - */ - private void hasLlmIdConfigured(ActionListener callback) { - MLMemoryContainerGetRequest getRequest = MLMemoryContainerGetRequest.builder().memoryContainerId(memoryContainerId).build(); - - client.execute(MLMemoryContainerGetAction.INSTANCE, getRequest, ActionListener.wrap(response -> { - boolean hasLlmId = response.getMlMemoryContainer().getConfiguration().getLlmId() != null; - log.info("Memory container {} has LLM ID configured: {}", memoryContainerId, hasLlmId); - callback.onResponse(hasLlmId); - }, ex -> { - log - .warn( - "Failed to get memory container {} configuration, defaulting infer to false. 
Error: {}", - memoryContainerId, - ex.getMessage(), - ex - ); - callback.onResponse(false); - })); - } - - private List> createTextContent(String text) { - return List.of(Map.of("type", "text", "text", text)); - } - - private List parseAgenticMemoryResponse(SearchResponse response) { - List chatMessages = new ArrayList<>(); - - for (SearchHit hit : response.getHits().getHits()) { - Map source = hit.getSourceAsMap(); - - // Parse working memory documents (conversational format) - if ("conversational".equals(source.get("payload_type"))) { - @SuppressWarnings("unchecked") - List> messages = (List>) source.get("messages"); - if (messages != null && messages.size() >= 2) { - // Extract user question and assistant response - String question = extractMessageText(messages.get(0)); // user message - String assistantResponse = extractMessageText(messages.get(1)); // assistant message - - if (question != null && assistantResponse != null) { - // Add user message - ChatMessage userMessage = ChatMessage - .builder() - .id(hit.getId() + "_user") - .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) - .sessionId(getSessionIdFromNamespace(source)) - .role("user") - .content(question) - .contentType("text") - .origin("agentic_memory_working") - .metadata( - Map - .of( - "payload_type", - source.get("payload_type"), - "memory_container_id", - source.get("memory_container_id"), - "namespace", - source.get("namespace"), - "tags", - source.get("tags") - ) - ) - .build(); - chatMessages.add(userMessage); - - // Add assistant message - ChatMessage assistantMessage = ChatMessage - .builder() - .id(hit.getId() + "_assistant") - .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) - .sessionId(getSessionIdFromNamespace(source)) - .role("assistant") - .content(assistantResponse) - .contentType("text") - .origin("agentic_memory_working") - .metadata( - Map - .of( - "payload_type", - source.get("payload_type"), - "memory_container_id", - source.get("memory_container_id"), - "namespace", - source.get("namespace"), - "tags", - source.get("tags") - ) - ) - .build(); - chatMessages.add(assistantMessage); - } - } - } - // Parse long-term memory documents (extracted facts) - else if (source.containsKey("memory") && source.containsKey("strategy_type")) { - String memory = (String) source.get("memory"); - String strategyType = (String) source.get("strategy_type"); - - // Convert extracted facts to chat message format for context - ChatMessage contextMessage = ChatMessage - .builder() - .id(hit.getId()) - .timestamp(Instant.ofEpochMilli((Long) source.get("created_time"))) - .sessionId(sessionId) // Use current session - .role("system") // System context message - .content("Context (" + strategyType + "): " + memory) // The extracted fact with context - .contentType("context") - .origin("agentic_memory_longterm") - .metadata( - Map - .of( - "strategy_type", - strategyType, - "strategy_id", - source.get("strategy_id"), - "memory_container_id", - source.get("memory_container_id") - ) - ) - .build(); - chatMessages.add(contextMessage); - } - } - - // Sort by timestamp to maintain chronological order - chatMessages.sort((a, b) -> a.getTimestamp().compareTo(b.getTimestamp())); - - return chatMessages; - } - - private String extractMessageText(Map message) { - if (message == null) - return null; - - @SuppressWarnings("unchecked") - List> content = (List>) message.get("content"); - if (content != null && !content.isEmpty()) { - Map firstContent = content.get(0); - return (String) 
firstContent.get("text"); - } - return null; - } - - private String getSessionIdFromNamespace(Map source) { - @SuppressWarnings("unchecked") - Map namespace = (Map) source.get("namespace"); - return namespace != null ? (String) namespace.get("session_id") : null; - } - - @Override - public void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { - if (listener == null) { - throw new IllegalArgumentException("ActionListener cannot be null"); - } - if (interactionId == null || interactionId.trim().isEmpty()) { - listener.onFailure(new IllegalArgumentException("Interaction ID is required and cannot be empty")); - return; - } - if (updateFields == null || updateFields.isEmpty()) { - listener.onFailure(new IllegalArgumentException("Update fields are required and cannot be empty")); - return; - } - - try { - log - .info( - "AgenticMemoryAdapter.updateInteraction: CALLED - Updating interaction {} with fields: {} in memory container: {}", - interactionId, - updateFields.keySet(), - memoryContainerId - ); - - // Convert updateFields to the format expected by memory container API - Map updateContent = new java.util.HashMap<>(); - - // Handle the response field - this is the main field we need to update - if (updateFields.containsKey("response")) { - String response = (String) updateFields.get("response"); - String question = (String) updateFields.getOrDefault("input", ""); - - // For working memory updates, we need to provide the complete messages array - // with both user question and assistant response - List> messages = Arrays - .asList( - Map.of("role", "user", "content", createTextContent(question)), - Map.of("role", "assistant", "content", createTextContent(response)) - ); - - updateContent.put("messages", messages); - - log - .debug( - "AgenticMemoryAdapter.updateInteraction: Updating messages for interaction {} with question: '{}' and response length: {}", - interactionId, - question.length() > 50 ? question.substring(0, 50) + "..." 
: question, - response.length() - ); - } - - // Handle other fields that might be updated - if (updateFields.containsKey("additional_info")) { - updateContent.put("additional_info", updateFields.get("additional_info")); - } - - MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(updateContent).build(); - - MLUpdateMemoryRequest request = MLUpdateMemoryRequest - .builder() - .memoryContainerId(memoryContainerId) - .memoryType(MemoryType.WORKING) // We're updating working memory - .memoryId(interactionId) - .mlUpdateMemoryInput(input) - .build(); - - client.execute(MLUpdateMemoryAction.INSTANCE, request, ActionListener.wrap(updateResponse -> { - log - .debug( - "AgenticMemoryAdapter.updateInteraction: Successfully updated interaction {} in memory container: {}", - interactionId, - memoryContainerId - ); - listener.onResponse(null); - }, ex -> { - log - .error( - "AgenticMemoryAdapter.updateInteraction: Failed to update interaction {} in memory container {}", - interactionId, - memoryContainerId, - ex - ); - listener.onFailure(ex); - })); - - } catch (Exception e) { - log - .error( - "AgenticMemoryAdapter.updateInteraction: Exception while updating interaction {} in memory container {}", - interactionId, - memoryContainerId, - e - ); - listener.onFailure(e); - } - } - -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java deleted file mode 100644 index 80743ba3c5..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatHistoryTemplateEngine.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import java.util.List; -import java.util.Map; - -/** - * Enhanced template system for ChatMessage-based memory types. - * Supports flexible templating with role-based formatting and metadata access. 
- */ -public interface ChatHistoryTemplateEngine { - /** - * Build chat history from ChatMessage list using template - * @param messages List of ChatMessage objects - * @param template Template string with placeholders - * @param context Additional context variables - * @return Formatted chat history string - */ - String buildChatHistory(List messages, String template, Map context); - - /** - * Get default template for basic chat history formatting - * @return Default template string - */ - default String getDefaultTemplate() { - return "{{#each messages}}{{role}}: {{content}}\n{{/each}}"; - } - - /** - * Get role-based template with enhanced formatting - * @return Role-based template string - */ - default String getRoleBasedTemplate() { - return """ - {{#each messages}} - {{#if (eq role 'user')}} - Human: {{content}} - {{else if (eq role 'assistant')}} - Assistant: {{content}} - {{else if (eq role 'system')}} - System: {{content}} - {{else if (eq role 'tool')}} - Tool Result: {{content}} - {{/if}} - {{#if metadata.confidence}} - (Confidence: {{metadata.confidence}}) - {{/if}} - {{/each}} - """; - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java deleted file mode 100644 index 88e952c806..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapter.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import java.util.List; - -import org.opensearch.core.action.ActionListener; - -/** - * Common interface for modern memory types supporting ChatMessage-based interactions. - * - *

<p>This interface provides a unified abstraction for different memory backend implementations,
- * enabling consistent interaction patterns across various memory storage systems. It supports
- * both conversation management and detailed trace data storage for comprehensive agent behavior
- * tracking.</p>
- *
- * <p>Supported Memory Types:</p>
- * <ul>
- *   <li>Agentic Memory - Local cluster-based intelligent memory system</li>
- *   <li>Remote Agentic Memory - Distributed agentic memory implementation</li>
- *   <li>Bedrock AgentCore Memory - AWS Bedrock agent memory integration</li>
- *   <li>Future memory types - Extensible for additional implementations</li>
- * </ul>
- *
- * <p>Core Capabilities:</p>
- * <ul>
- *   <li>Message retrieval in standardized ChatMessage format</li>
- *   <li>Conversation and session management</li>
- *   <li>Interaction persistence with metadata support</li>
- *   <li>Tool execution trace data storage</li>
- *   <li>Dynamic interaction updates</li>
- * </ul>
- *
- * <p>Note: ConversationIndex uses a separate legacy pipeline for backward compatibility
- * and is not part of this modern interface hierarchy.</p>
- * - * @see ChatMessage - * @see AgenticMemoryAdapter - */ -public interface ChatMemoryAdapter { - /** - * Retrieve conversation messages in ChatMessage format - * @param listener ActionListener to handle the response - */ - void getMessages(ActionListener> listener); - - /** - * Get the conversation/session identifier - * @return conversation ID or session ID - */ - String getConversationId(); - - /** - * This is the main memory container ID used to identify the memory container - * in the memory management system. - * @return - */ - String getMemoryContainerId(); - - /** - * Save interaction to memory (optional implementation) - * @param question User question - * @param response AI response - * @param parentId Parent interaction ID - * @param traceNum Trace number - * @param action Action performed - * @param listener ActionListener to handle the response - */ - default void saveInteraction( - String question, - String response, - String parentId, - Integer traceNum, - String action, - ActionListener listener - ) { - listener.onFailure(new UnsupportedOperationException("Save not implemented")); - } - - /** - * Update existing interaction with additional information - * @param interactionId Interaction ID to update - * @param updateFields Fields to update (e.g., final answer, additional info) - * @param listener ActionListener to handle the response - */ - default void updateInteraction(String interactionId, java.util.Map updateFields, ActionListener listener) { - listener.onFailure(new UnsupportedOperationException("Update interaction not implemented")); - } - - /** - * Save trace data as tool invocation data in working memory. - * - *

<p>This method provides a standardized way to store detailed information about
- * tool executions, enabling comprehensive tracking and analysis of agent behavior.
- * Implementations should store this data in a structured format that supports
- * later retrieval and analysis.</p>
- *
- * <p>Default implementation throws UnsupportedOperationException. Memory adapters
- * that support trace data storage should override this method.</p>
- * - * @param toolName Name of the tool that was executed (required) - * @param toolInput Input parameters passed to the tool (may be null) - * @param toolOutput Output/response from the tool execution (may be null) - * @param parentMemoryId Parent memory ID to associate this trace with (may be null) - * @param traceNum Trace sequence number for ordering (may be null) - * @param action Action/origin identifier for categorization (may be null) - * @param listener ActionListener to handle the response with created trace ID - * @throws UnsupportedOperationException if the implementation doesn't support trace data storage - */ - default void saveTraceData( - String toolName, - String toolInput, - String toolOutput, - String parentMemoryId, - Integer traceNum, - String action, - ActionListener listener - ) { - listener.onFailure(new UnsupportedOperationException("Save trace data not implemented")); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java deleted file mode 100644 index 31dd72604d..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/ChatMessage.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import java.time.Instant; -import java.util.Map; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Getter; - -/** - * Enhanced memory message for chat agents - designed for extensibility. - * Supports multiple memory types: Agentic, Remote Agentic, Bedrock AgentCore, etc. - * - * Design Philosophy: - * - Text-first with rich metadata (hybrid approach) - * - Extensible for future multimodal content - * - Memory-type agnostic interface - * - Role-based message support - */ -@Builder -@AllArgsConstructor -@Getter -public class ChatMessage { - private String id; - private Instant timestamp; - private String sessionId; - private String role; // "user", "assistant", "system", "tool" - private String content; // Primary text content - private String contentType; // "text", "image", "tool_result", etc. (metadata) - private String origin; // "agentic_memory", "remote_agentic", "bedrock_agentcore", etc. 
- private Map metadata; // Rich content details and memory-specific data -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 4b44c55738..1594506cf4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -27,7 +27,6 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.UUID; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; @@ -47,7 +46,6 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; -import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -97,7 +95,6 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener { public static final String MEMORY_ID = "memory_id"; - public static final String MEMORY_CONTAINER_ID = "memory_container_id"; public static final String QUESTION = "question"; public static final String PARENT_INTERACTION_ID = "parent_interaction_id"; public static final String REGENERATE_INTERACTION_ID = "regenerate_interaction_id"; @@ -177,211 +174,194 @@ public void execute(Input input, ActionListener listener, TransportChann if (MLIndicesHandler.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_AGENT_INDEX)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient.getDataObjectAsync(getDataObjectRequest).whenComplete((response, throwable) -> { - context.restore(); - log.debug("Completed Get Agent Request, Agent id:{}", agentId); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { - log.error("Failed to get Agent index", cause); - listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); + sdkClient + .getDataObjectAsync(getDataObjectRequest, client.threadPool().executor("opensearch_ml_general")) + .whenComplete((response, throwable) -> { + context.restore(); + log.debug("Completed Get Agent Request, Agent id:{}", agentId); + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); + if (ExceptionsHelper.unwrap(cause, IndexNotFoundException.class) != null) { + log.error("Failed to get Agent index", cause); + listener.onFailure(new OpenSearchStatusException("Failed to get agent index", RestStatus.NOT_FOUND)); + } else { + log.error("Failed to get ML Agent {}", agentId, cause); + listener.onFailure(cause); + } } else { - log.error("Failed to get ML Agent {}", agentId, cause); - listener.onFailure(cause); - } - } else { - try { - GetResponse getAgentResponse = response.parser() == null ? 
null : GetResponse.fromXContent(response.parser()); - if (getAgentResponse != null && getAgentResponse.isExists()) { - try ( - XContentParser parser = jsonXContent - .createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - getAgentResponse.getSourceAsString() - ) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLAgent mlAgent = MLAgent.parse(parser); - if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { - listener - .onFailure( - new OpenSearchStatusException( - "You don't have permission to access this resource", - RestStatus.FORBIDDEN - ) + try { + GetResponse getAgentResponse = response.parser() == null + ? null + : GetResponse.fromXContent(response.parser()); + if (getAgentResponse != null && getAgentResponse.isExists()) { + try ( + XContentParser parser = jsonXContent + .createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + getAgentResponse.getSourceAsString() + ) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLAgent mlAgent = MLAgent.parse(parser); + if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { + listener + .onFailure( + new OpenSearchStatusException( + "You don't have permission to access this resource", + RestStatus.FORBIDDEN + ) + ); + } + MLMemorySpec memorySpec = mlAgent.getMemory(); + String memoryId = inputDataSet.getParameters().get(MEMORY_ID); + String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); + String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); + String appType = mlAgent.getAppType(); + String question = inputDataSet.getParameters().get(QUESTION); + + if (parentInteractionId != null && regenerateInteractionId != null) { + throw new IllegalArgumentException( + "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." ); - } - MLMemorySpec memorySpec = mlAgent.getMemory(); - String memoryId; - if (Objects.equals(mlAgent.getMemory().getType(), MLMemoryType.CONVERSATION_INDEX.name())) { - memoryId = inputDataSet.getParameters().get(MEMORY_ID); - } else { - memoryId = inputDataSet.getParameters().get(MEMORY_CONTAINER_ID); - } - - String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); - String appType = mlAgent.getAppType(); - String question = inputDataSet.getParameters().get(QUESTION); - - if (parentInteractionId != null && regenerateInteractionId != null) { - throw new IllegalArgumentException( - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." 
- ); - } - - MLTask mlTask = MLTask - .builder() - .taskType(MLTaskType.AGENT_EXECUTION) - .functionName(FunctionName.AGENT) - .state(MLTaskState.CREATED) - .workerNodes(ImmutableList.of(clusterService.localNode().getId())) - .createTime(Instant.now()) - .lastUpdateTime(Instant.now()) - .async(false) - .tenantId(tenantId) - .build(); - - if (memoryId == null && regenerateInteractionId != null) { - throw new IllegalArgumentException("A memory ID must be provided to regenerate."); - } + } - // NEW: Handle AGENTIC_MEMORY type before ConversationIndex logic - if (memorySpec != null && "AGENTIC_MEMORY".equals(memorySpec.getType())) { - log.debug("Detected AGENTIC_MEMORY type - routing to agentic memory handler"); - handleAgenticMemory( - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - listener, - channel - ); - } - // EXISTING: ConversationIndex logic remains unchanged - else if (memorySpec != null - && memorySpec.getType() != null - && memoryFactoryMap.containsKey(memorySpec.getType()) - && (memoryId == null || parentInteractionId == null)) { - ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); - conversationIndexMemoryFactory.create(question, memoryId, appType, ActionListener.wrap(memory -> { - inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // get question for regenerate - if (regenerateInteractionId != null) { - log.info("Regenerate for existing interaction {}", regenerateInteractionId); - client - .execute( - GetInteractionAction.INSTANCE, - new GetInteractionRequest(regenerateInteractionId), - ActionListener.wrap(interactionRes -> { - inputDataSet - .getParameters() - .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel + MLTask mlTask = MLTask + .builder() + .taskType(MLTaskType.AGENT_EXECUTION) + .functionName(FunctionName.AGENT) + .state(MLTaskState.CREATED) + .workerNodes(ImmutableList.of(clusterService.localNode().getId())) + .createTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .async(false) + .tenantId(tenantId) + .build(); + + if (memoryId == null && regenerateInteractionId != null) { + throw new IllegalArgumentException("A memory ID must be provided to regenerate."); + } + if (memorySpec != null + && memorySpec.getType() != null + && memoryFactoryMap.containsKey(memorySpec.getType()) + && (memoryId == null || parentInteractionId == null)) { + ConversationIndexMemory.Factory conversationIndexMemoryFactory = + (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); + conversationIndexMemoryFactory + .create(question, memoryId, appType, ActionListener.wrap(memory -> { + inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + }, e -> { + log.error("Failed to get existing interaction for 
regeneration", e); + listener.onFailure(e); + }) ); - }, e -> { - log.error("Failed to get existing interaction for regeneration", e); - listener.onFailure(e); - }) - ); - } else { - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - } - }, ex -> { - log.error("Failed to read conversation memory", ex); - listener.onFailure(ex); - })); - } else { - // For existing conversations, create memory instance using factory - if (memorySpec != null && memorySpec.getType() != null) { - ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - if (factory != null) { - // memoryId exists, so create returns an object with existing memory, therefore name can - // be null - factory - .create( - null, - memoryId, - appType, - ActionListener - .wrap( - createdMemory -> executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - createdMemory, - channel - ), - ex -> { - log.error("Failed to find memory with memory_id: {}", memoryId, ex); - listener.onFailure(ex); - } - ) - ); - return; + } else { + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } + }, ex -> { + log.error("Failed to read conversation memory", ex); + listener.onFailure(ex); + })); + } else { + // For existing conversations, create memory instance using factory + if (memorySpec != null && memorySpec.getType() != null) { + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap + .get(memorySpec.getType()); + if (factory != null) { + // memoryId exists, so create returns an object with existing memory, therefore name can + // be null + factory + .create( + null, + memoryId, + appType, + ActionListener + .wrap( + createdMemory -> executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + createdMemory, + channel + ), + ex -> { + log.error("Failed to find memory with memory_id: {}", memoryId, ex); + listener.onFailure(ex); + } + ) + ); + return; + } } + executeAgent( + inputDataSet, + mlTask, + isAsync, + memoryId, + mlAgent, + outputs, + modelTensors, + listener, + null, + channel + ); } - executeAgent( - inputDataSet, - mlTask, - isAsync, - memoryId, - mlAgent, - outputs, - modelTensors, - listener, - null, - channel - ); + } catch (Exception e) { + log.error("Failed to parse ml agent {}", agentId, e); + listener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to parse ml agent {}", agentId, e); - listener.onFailure(e); + } else { + listener + .onFailure( + new OpenSearchStatusException( + "Failed to find agent with the provided agent id: " + agentId, + RestStatus.NOT_FOUND + ) + ); } - } else { - listener - .onFailure( - new OpenSearchStatusException( - "Failed to find agent with the provided agent id: " + agentId, - RestStatus.NOT_FOUND - ) - ); + } catch (Exception e) { + log.error("Failed to get agent", e); + listener.onFailure(e); } - } catch (Exception e) { - log.error("Failed to get agent", e); - listener.onFailure(e); } - } - }); + }); } } else { listener.onFailure(new ResourceNotFoundException("Agent index not found")); @@ -476,7 +456,7 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - Object memory, // Accept both ConversationIndexMemory and AgenticMemoryAdapter + 
ConversationIndexMemory memory, TransportChannel channel ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; @@ -492,23 +472,12 @@ private void executeAgent( // If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); + if (memoryId != null && !memoryId.isEmpty()) { + agentResponse.put(MEMORY_ID, memoryId); + } - // Handle different memory types for response - if (memory instanceof AgenticMemoryAdapter) { - AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; - agentResponse.put(MEMORY_ID, adapter.getMemoryContainerId()); // memory_container_id - agentResponse.put("session_id", adapter.getConversationId()); // session_id - if (parentInteractionId != null && !parentInteractionId.isEmpty()) { - agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); // actual interaction ID - } - } else { - // ConversationIndex behavior (unchanged) - if (memoryId != null && !memoryId.isEmpty()) { - agentResponse.put(MEMORY_ID, memoryId); - } - if (parentInteractionId != null && !parentInteractionId.isEmpty()) { - agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); - } + if (parentInteractionId != null && !parentInteractionId.isEmpty()) { + agentResponse.put(PARENT_INTERACTION_ID, parentInteractionId); } mlTask.setResponse(agentResponse); mlTask.setAsync(true); @@ -566,7 +535,7 @@ private ActionListener createAgentActionListener( List modelTensors, String agentType, String parentInteractionId, - Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter + ConversationIndexMemory memory ) { return ActionListener.wrap(output -> { if (output != null) { @@ -587,7 +556,7 @@ private ActionListener createAsyncTaskUpdater( List outputs, List modelTensors, String parentInteractionId, - Object memory // Accept both ConversationIndexMemory and AgenticMemoryAdapter + ConversationIndexMemory memory ) { String taskId = mlTask.getTaskId(); Map agentResponse = new HashMap<>(); @@ -614,7 +583,6 @@ private ActionListener createAsyncTaskUpdater( e -> log.error("Failed to update ML task {} with agent execution results", taskId) ) ); - }, ex -> { agentResponse.put(ERROR_MESSAGE, ex.getMessage()); @@ -743,259 +711,23 @@ public void indexMLTask(MLTask mlTask, ActionListener listener) { } } - /** - * Handle agentic memory type requests - */ - private void handleAgenticMemory( - RemoteInferenceInputDataSet inputDataSet, - MLTask mlTask, - boolean isAsync, - List outputs, - List modelTensors, - MLAgent mlAgent, - ActionListener listener, - TransportChannel channel - ) { - // Extract parameters - String memoryContainerId = inputDataSet.getParameters().get("memory_container_id"); - String sessionId = inputDataSet.getParameters().get("session_id"); - String ownerId = inputDataSet.getParameters().get("owner_id"); - String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - - log.debug("MLAgentExecutor: Processing AGENTIC_MEMORY request with parameters: {}", inputDataSet.getParameters().keySet()); - log - .debug( - "Extracted agentic memory parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}, parentInteractionId: {}", - memoryContainerId != null ? "present" : "null", - sessionId != null ? "present" : "null", - ownerId != null ? "present" : "null", - parentInteractionId != null ? 
"present" : "null" - ); - - // Parameter validation - if (memoryContainerId == null) { - log - .error( - "AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", - inputDataSet.getParameters().keySet() - ); - listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); - return; - } - - if (ownerId == null) { - log.error("AGENTIC_MEMORY validation failed: owner_id is null. Available params: {}", inputDataSet.getParameters().keySet()); - listener.onFailure(new IllegalArgumentException("owner_id is required for agentic memory")); - return; - } - - log.debug("AGENTIC_MEMORY parameter validation successful - memoryContainerId: {}, ownerId: {}", memoryContainerId, ownerId); - - // Session management (same pattern as ConversationIndex) - boolean isNewConversation = Strings.isEmpty(sessionId) || parentInteractionId == null; - log - .debug( - "Conversation type determination - sessionId: {}, parentInteractionId: {}, isNewConversation: {}", - sessionId != null ? "present" : "null", - parentInteractionId != null ? "present" : "null", - isNewConversation - ); - - if (isNewConversation) { - if (Strings.isEmpty(sessionId)) { - sessionId = UUID.randomUUID().toString(); // NEW conversation - log.debug("Generated new agentic memory session: {}", sessionId); - } else { - log.debug("Using provided session ID for new conversation: {}", sessionId); - } - } else { - log - .debug( - "Continuing existing agentic memory conversation - sessionId: {}, parentInteractionId: {}", - sessionId, - parentInteractionId - ); - } - - // Create AgenticMemoryAdapter - log - .debug( - "Creating AgenticMemoryAdapter with parameters - memoryContainerId: {}, sessionId: {}, ownerId: {}", - memoryContainerId, - sessionId, - ownerId - ); - try { - AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); - log - .debug( - "AgenticMemoryAdapter created successfully - memoryContainerId: {}, sessionId: {}, conversationId: {}", - memoryContainerId, - sessionId, - adapter.getConversationId() - ); - - // Route to appropriate execution path - if (isNewConversation) { - // NEW conversation: create root interaction first - log - .debug( - "Execution path: NEW conversation - routing to saveRootInteractionAndExecuteAgentic for sessionId: {}", - sessionId - ); - saveRootInteractionAndExecuteAgentic( - listener, - adapter, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - } else { - // EXISTING conversation: execute directly - log - .debug( - "Execution path: EXISTING conversation - routing to executeAgent for sessionId: {}, parentInteractionId: {}", - sessionId, - parentInteractionId - ); - executeAgent( - inputDataSet, - mlTask, - isAsync, - adapter.getMemoryContainerId(), - mlAgent, - outputs, - modelTensors, - listener, - adapter, - channel - ); - } - } catch (Exception ex) { - log - .error( - "AgenticMemoryAdapter creation failed - memoryContainerId: {}, sessionId: {}, ownerId: {}, error: {}", - memoryContainerId, - sessionId, - ownerId, - ex.getMessage(), - ex - ); - listener.onFailure(ex); - } - } - - /** - * Create root interaction for new agentic memory conversations (mirrors ConversationIndex pattern for tool tracing support) - */ - private void saveRootInteractionAndExecuteAgentic( - ActionListener listener, - AgenticMemoryAdapter adapter, - RemoteInferenceInputDataSet inputDataSet, - MLTask mlTask, - boolean isAsync, - List outputs, - List modelTensors, - MLAgent mlAgent, - 
TransportChannel channel - ) { - String question = inputDataSet.getParameters().get(QUESTION); - - log - .debug( - "Creating root interaction for agentic memory - memoryContainerId: {}, sessionId: {}, question: {}", - adapter.getMemoryContainerId(), - adapter.getConversationId(), - question != null ? "present" : "null" - ); - - // Create root interaction with empty response (same pattern as ConversationIndex) - // This enables tool tracing and proper interaction updating - adapter.saveInteraction(question, "", null, 0, "ROOT", ActionListener.wrap(interactionId -> { - log - .info( - "Root interaction created successfully for agentic memory - interactionId: {}, memoryContainerId: {}, sessionId: {}", - interactionId, - adapter.getMemoryContainerId(), - adapter.getConversationId() - ); - inputDataSet.getParameters().put(PARENT_INTERACTION_ID, interactionId); + private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) { + if (interactionId != null && memory != null) { + String failureMessage = "Agent execution failed: " + errorMessage; + Map updateContent = new HashMap<>(); + updateContent.put(RESPONSE_FIELD, failureMessage); - log - .debug( - "Proceeding to executeAgent with root interaction - interactionId: {}, sessionId: {}", + memory + .getMemoryManager() + .updateInteraction( interactionId, - adapter.getConversationId() - ); - - executeAgent( - inputDataSet, - mlTask, - isAsync, - adapter.getMemoryContainerId(), // Use memory_container_id as memoryId for agentic memory - mlAgent, - outputs, - modelTensors, - listener, - adapter, - channel - ); - }, ex -> { - log - .error( - "Root interaction creation failed for agentic memory - memoryContainerId: {}, sessionId: {}, error: {}", - adapter.getMemoryContainerId(), - adapter.getConversationId(), - ex.getMessage(), - ex + updateContent, + ActionListener + .wrap( + res -> log.info("Updated interaction {} with failure message", interactionId), + e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) + ) ); - listener.onFailure(ex); - })); - } - - private void updateInteractionWithFailure(String interactionId, Object memory, String errorMessage) { - if (interactionId != null && memory != null) { - if (memory instanceof ConversationIndexMemory) { - // Existing ConversationIndex error handling - ConversationIndexMemory conversationMemory = (ConversationIndexMemory) memory; - String failureMessage = "Agent execution failed: " + errorMessage; - Map updateContent = new HashMap<>(); - updateContent.put(RESPONSE_FIELD, failureMessage); - - conversationMemory - .getMemoryManager() - .updateInteraction( - interactionId, - updateContent, - ActionListener - .wrap( - res -> log.info("Updated interaction {} with failure message", interactionId), - e -> log.warn("Failed to update interaction {} with failure message", interactionId, e) - ) - ); - } else if (memory instanceof AgenticMemoryAdapter) { - // New agentic memory error handling - AgenticMemoryAdapter adapter = (AgenticMemoryAdapter) memory; - Map updateFields = new HashMap<>(); - updateFields.put("error", errorMessage); - - adapter - .updateInteraction( - interactionId, - updateFields, - ActionListener - .wrap( - res -> log.info("Updated agentic memory interaction {} with failure message", interactionId), - e -> log.warn("Failed to update agentic memory interaction {} with failure message", interactionId, e) - ) - ); - } else { - log.warn("Unknown memory type for error handling: {}", 
memory.getClass().getSimpleName()); - } } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f22e295062..7e1a4050bd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -43,7 +43,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; @@ -58,7 +57,6 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -78,6 +76,8 @@ import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; +import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -177,60 +177,78 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener functionCalling.configure(params); } + String memoryType = mlAgent.getMemory().getType(); + String memoryId = params.get(MLAgentExecutor.MEMORY_ID); + String appType = mlAgent.getAppType(); + String title = params.get(MLAgentExecutor.QUESTION); String chatHistoryPrefix = params.getOrDefault(PROMPT_CHAT_HISTORY_PREFIX, CHAT_HISTORY_PREFIX); String chatHistoryQuestionTemplate = params.get(CHAT_HISTORY_QUESTION_TEMPLATE); String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - createMemoryAdapter(mlAgent, params, ActionListener.wrap(memoryOrAdapter -> { - log.debug("createMemoryAdapter callback: memoryOrAdapter type = {}", memoryOrAdapter.getClass().getSimpleName()); - - if (memoryOrAdapter instanceof ConversationIndexMemory) { - // Existing ConversationIndex flow - zero changes - ConversationIndexMemory memory = (ConversationIndexMemory) memoryOrAdapter; - memory.getMessages(ActionListener.>wrap(r -> { - processLegacyInteractions( - r, - memory.getConversationId(), - memory, - mlAgent, - params, - inputParams, - chatHistoryPrefix, - chatHistoryQuestionTemplate, - chatHistoryResponseTemplate, - functionCalling, - listener - ); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - }), messageHistoryLimit); - - } else if (memoryOrAdapter instanceof ChatMemoryAdapter) { - // Modern Pipeline - NEW ChatMessage processing - log.debug("Routing to modern ChatMemoryAdapter pipeline"); - ChatMemoryAdapter adapter = (ChatMemoryAdapter) memoryOrAdapter; - adapter.getMessages(ActionListener.wrap(chatMessages -> { - // Use NEW ChatMessage-based processing (no conversion to Interaction) - processModernChatMessages( - chatMessages, - adapter.getConversationId(), - adapter, // Add ChatMemoryAdapter parameter - mlAgent, - params, - inputParams, 
- functionCalling, - listener - ); - }, e -> { - log.error("Failed to get chat history from modern memory adapter", e); - listener.onFailure(e); - })); + ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); + conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { + // TODO: call runAgent directly if messageHistoryLimit == 0 + memory.getMessages(ActionListener.>wrap(r -> { + List messageList = new ArrayList<>(); + for (Interaction next : r) { + String question = next.getInput(); + String response = next.getResponse(); + // As we store the conversation with empty response first and then update when have final answer, + // filter out those in-flight requests when run in parallel + if (Strings.isNullOrEmpty(response)) { + continue; + } + messageList + .add( + ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId(memory.getConversationId()) + .question(question) + .response(response) + .build() + ); + } + if (!messageList.isEmpty()) { + if (chatHistoryQuestionTemplate == null) { + StringBuilder chatHistoryBuilder = new StringBuilder(); + chatHistoryBuilder.append(chatHistoryPrefix); + for (Message message : messageList) { + chatHistoryBuilder.append(message.toString()).append("\n"); + } + params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); + } else { + List chatHistory = new ArrayList<>(); + for (Message message : messageList) { + Map messageParams = new HashMap<>(); + messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion())); + + StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); + chatHistory.add(chatQuestionMessage); + + messageParams.clear(); + messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); + substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); + String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); + chatHistory.add(chatResponseMessage); + } + params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); + } + } - } else { - listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryOrAdapter.getClass())); - } + runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + }), messageHistoryLimit); }, listener::onFailure)); } @@ -238,7 +256,7 @@ private void runAgent( MLAgent mlAgent, Map params, ActionListener listener, - Object memoryOrSessionId, // Can be Memory object or String sessionId + Memory memory, String sessionId, FunctionCalling functionCalling ) { @@ -249,71 +267,7 @@ private void runAgent( Map tools = new HashMap<>(); Map toolSpecMap = new HashMap<>(); createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent); - - // Route to correct runReAct method based on memory type - 
if (memoryOrSessionId instanceof Memory) { - // Legacy ConversationIndex path - Memory actualMemory = (Memory) memoryOrSessionId; - runReAct( - mlAgent.getLlm(), - tools, - toolSpecMap, - params, - actualMemory, - sessionId, - mlAgent.getTenantId(), - listener, - functionCalling - ); - } else { - // Modern agentic memory path - create ChatMemoryAdapter - String memoryContainerId = params.get("memory_container_id"); - String ownerId = params.get("owner_id"); - - log - .debug( - "Agentic memory path: memoryContainerId={}, ownerId={}, sessionId={}, allParams={}", - memoryContainerId, - ownerId, - sessionId, - params.keySet() - ); - - if (memoryContainerId != null && ownerId != null) { - AgenticMemoryAdapter chatMemoryAdapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); - runReAct( - mlAgent.getLlm(), - tools, - toolSpecMap, - params, - chatMemoryAdapter, - sessionId, - mlAgent.getTenantId(), - listener, - functionCalling - ); - } else { - // Missing required parameters for agentic memory - log - .error( - "Agentic memory requested but missing required parameters. " - + "memory_container_id: {}, owner_id: {}, available params: {}", - memoryContainerId, - ownerId, - params.keySet() - ); - listener - .onFailure( - new IllegalArgumentException( - "Agentic memory requires 'memory_container_id' and 'owner_id' parameters. " - + "Provided: memory_container_id=" - + memoryContainerId - + ", owner_id=" - + ownerId - ) - ); - } - } + runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling); }; // Fetch MCP tools and handle both success and failure cases @@ -433,32 +387,17 @@ private void runReAct( .build() ); - // Save trace data using appropriate memory adapter - if (memory instanceof ConversationIndexMemory) { - saveTraceData( - (ConversationIndexMemory) memory, - memory.getType(), - question, - thoughtResponse, - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - "LLM" - ); - } else if (memory instanceof ChatMemoryAdapter) { - saveTraceData( - (ChatMemoryAdapter) memory, - memory.getType(), - question, - thoughtResponse, - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - "LLM" - ); - } + saveTraceData( + conversationIndexMemory, + memory.getType(), + question, + thoughtResponse, + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + "LLM" + ); if (nextStepListener == null) { handleMaxIterationsReached( @@ -527,32 +466,17 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - // Save trace data using appropriate memory adapter - if (memory instanceof ConversationIndexMemory) { - saveTraceData( - (ConversationIndexMemory) memory, - "ReAct", - lastActionInput.get(), - outputToOutputString(filteredOutput), - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - lastAction.get() - ); - } else if (memory instanceof ChatMemoryAdapter) { - saveTraceData( - (ChatMemoryAdapter) memory, - "ReAct", - lastActionInput.get(), - outputToOutputString(filteredOutput), - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - lastAction.get() - ); - } + saveTraceData( + conversationIndexMemory, + "ReAct", + lastActionInput.get(), + outputToOutputString(filteredOutput), + sessionId, + traceDisabled, + parentInteractionId, + traceNumber, + lastAction.get() + ); StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); 
newPrompt.set(substitutor.replace(finalPrompt)); @@ -657,7 +581,7 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, new ArrayList<>(Collections.singletonList(outputString))); + additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); } } } @@ -781,45 +705,6 @@ public static void saveTraceData( } } - /** - * Overloaded saveTraceData method for ChatMemoryAdapter - */ - public static void saveTraceData( - ChatMemoryAdapter chatMemoryAdapter, - String memoryType, - String question, - String thoughtResponse, - String sessionId, - boolean traceDisabled, - String parentInteractionId, - AtomicInteger traceNumber, - String origin - ) { - if (chatMemoryAdapter != null && !traceDisabled) { - // Save trace data as tool invocation data in working memory - chatMemoryAdapter - .saveTraceData( - origin, // toolName (LLM, ReAct, etc.) - question, // toolInput - thoughtResponse, // toolOutput - parentInteractionId, // parentMemoryId - traceNumber.addAndGet(1), // traceNum - origin, // action - ActionListener - .wrap( - r -> log - .debug( - "Successfully saved trace data via ChatMemoryAdapter for session: {}, origin: {}", - sessionId, - origin - ), - e -> log - .warn("Failed to save trace data via ChatMemoryAdapter for session: {}, origin: {}", sessionId, origin, e) - ) - ); - } - } - private void sendFinalAnswer( String sessionId, ActionListener listener, @@ -874,51 +759,6 @@ private void sendFinalAnswer( } } - /** - * Overloaded sendFinalAnswer method for modern ChatMemoryAdapter pipeline - */ - private void sendFinalAnswer( - String sessionId, - ActionListener listener, - String question, - String parentInteractionId, - boolean verbose, - boolean traceDisabled, - List cotModelTensors, - ChatMemoryAdapter chatMemoryAdapter, // Modern parameter - AtomicInteger traceNumber, - Map additionalInfo, - String finalAnswer - ) { - // Send completion chunk for streaming - streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); - - if (chatMemoryAdapter != null) { - String copyOfFinalAnswer = finalAnswer; - ActionListener saveTraceListener = ActionListener.wrap(r -> { - // For ChatMemoryAdapter, we don't have separate updateInteraction - // The saveInteraction method handles the complete saving - streamingWrapper - .sendFinalResponse( - sessionId, - listener, - parentInteractionId, - verbose, - cotModelTensors, - additionalInfo, - copyOfFinalAnswer - ); - }, listener::onFailure); - - // Use ChatMemoryAdapter's saveInteraction method - chatMemoryAdapter - .saveInteraction(question, finalAnswer, parentInteractionId, traceNumber.addAndGet(1), "LLM", saveTraceListener); - } else { - streamingWrapper - .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); - } - } - public static List createModelTensors(String sessionId, String parentInteractionId) { List cotModelTensors = new ArrayList<>(); @@ -1023,7 +863,7 @@ public static void returnFinalResponse( ModelTensor .builder() .name("response") - .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -1068,305 +908,6 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } - /** - * Overloaded handleMaxIterationsReached method for ChatMemoryAdapter - */ - private void handleMaxIterationsReached( - String 
sessionId, - ActionListener listener, - String question, - String parentInteractionId, - boolean verbose, - boolean traceDisabled, - List traceTensors, - ChatMemoryAdapter chatMemoryAdapter, // Modern parameter - AtomicInteger traceNumber, - Map additionalInfo, - AtomicReference lastThought, - int maxIterations, - Map tools - ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - sendFinalAnswer( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory - traceNumber, - additionalInfo, - incompleteResponse - ); - cleanUpResource(tools); - } - - /** - * Complete runReAct method for modern ChatMemoryAdapter pipeline - * This method handles the new memory types (agentic, remote, bedrock, etc.) - * - * Full implementation with complete ReAct loop, tool execution, trace saving, and streaming. - */ - private void runReAct( - LLMSpec llm, - Map tools, - Map toolSpecMap, - Map parameters, - ChatMemoryAdapter chatMemoryAdapter, // Modern parameter - String sessionId, - String tenantId, - ActionListener listener, - FunctionCalling functionCalling - ) { - Map tmpParameters = constructLLMParams(llm, parameters); - String prompt = constructLLMPrompt(tools, tmpParameters); - tmpParameters.put(PROMPT, prompt); - final String finalPrompt = prompt; - - String question = tmpParameters.get(MLAgentExecutor.QUESTION); - String parentInteractionId = tmpParameters.get(MLAgentExecutor.PARENT_INTERACTION_ID); - boolean verbose = Boolean.parseBoolean(tmpParameters.getOrDefault(VERBOSE, "false")); - boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); - - // Trace number - AtomicInteger traceNumber = new AtomicInteger(0); - - AtomicReference> lastLlmListener = new AtomicReference<>(); - AtomicReference lastThought = new AtomicReference<>(); - AtomicReference lastAction = new AtomicReference<>(); - AtomicReference lastActionInput = new AtomicReference<>(); - AtomicReference lastToolSelectionResponse = new AtomicReference<>(); - Map additionalInfo = new ConcurrentHashMap<>(); - Map lastToolParams = new ConcurrentHashMap<>(); - - StepListener firstListener = new StepListener(); - lastLlmListener.set(firstListener); - StepListener lastStepListener = firstListener; - - StringBuilder scratchpadBuilder = new StringBuilder(); - List interactions = new CopyOnWriteArrayList<>(); - - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); - AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); - tmpParameters.put(PROMPT, newPrompt.get()); - List traceTensors = createModelTensors(sessionId, parentInteractionId); - int maxIterations = Integer.parseInt(tmpParameters.getOrDefault(MAX_ITERATION, DEFAULT_MAX_ITERATIONS)); - - for (int i = 0; i < maxIterations; i++) { - int finalI = i; - StepListener nextStepListener = (i == maxIterations - 1) ? 
null : new StepListener<>(); - - lastStepListener.whenComplete(output -> { - StringBuilder sessionMsgAnswerBuilder = new StringBuilder(); - if (finalI % 2 == 0) { - MLTaskResponse llmResponse = (MLTaskResponse) output; - ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); - List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); - Map modelOutput = parseLLMOutput( - parameters, - tmpModelTensorOutput, - llmResponsePatterns, - tools.keySet(), - interactions, - functionCalling - ); - - streamingWrapper.fixInteractionRole(interactions); - String thought = String.valueOf(modelOutput.get(THOUGHT)); - String toolCallId = String.valueOf(modelOutput.get("tool_call_id")); - String action = String.valueOf(modelOutput.get(ACTION)); - String actionInput = String.valueOf(modelOutput.get(ACTION_INPUT)); - String thoughtResponse = modelOutput.get(THOUGHT_RESPONSE); - String finalAnswer = modelOutput.get(FINAL_ANSWER); - - if (finalAnswer != null) { - finalAnswer = finalAnswer.trim(); - sendFinalAnswer( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - chatMemoryAdapter, // Use ChatMemoryAdapter instead of ConversationIndexMemory - traceNumber, - additionalInfo, - finalAnswer - ); - cleanUpResource(tools); - return; - } - - sessionMsgAnswerBuilder.append(thought); - lastThought.set(thought); - lastAction.set(action); - lastActionInput.set(actionInput); - lastToolSelectionResponse.set(thoughtResponse); - - traceTensors - .add( - ModelTensors - .builder() - .mlModelTensors(List.of(ModelTensor.builder().name("response").result(thoughtResponse).build())) - .build() - ); - - // Save trace data using ChatMemoryAdapter - saveTraceData( - chatMemoryAdapter, - "ChatMemoryAdapter", // Memory type for modern pipeline - question, - thoughtResponse, - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - "LLM" - ); - - if (nextStepListener == null) { - handleMaxIterationsReached( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - chatMemoryAdapter, // Use ChatMemoryAdapter - traceNumber, - additionalInfo, - lastThought, - maxIterations, - tools - ); - return; - } - - if (tools.containsKey(action)) { - Map toolParams = constructToolParams( - tools, - toolSpecMap, - question, - lastActionInput, - action, - actionInput - ); - lastToolParams.clear(); - lastToolParams.putAll(toolParams); - runTool( - tools, - toolSpecMap, - tmpParameters, - (ActionListener) nextStepListener, - action, - actionInput, - toolParams, - interactions, - toolCallId, - functionCalling - ); - - } else { - String res = String.format(Locale.ROOT, "Failed to run the tool %s which is unsupported.", action); - StringSubstitutor substitutor = new StringSubstitutor( - Map.of(SCRATCHPAD, scratchpadBuilder.toString()), - "${parameters.", - "}" - ); - newPrompt.set(substitutor.replace(finalPrompt)); - tmpParameters.put(PROMPT, newPrompt.get()); - ((ActionListener) nextStepListener).onResponse(res); - } - } else { - Object filteredOutput = filterToolOutput(lastToolParams, output); - addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); - - String toolResponse = constructToolResponse( - tmpParameters, - lastAction, - lastActionInput, - lastToolSelectionResponse, - filteredOutput - ); - scratchpadBuilder.append(toolResponse).append("\n\n"); - - // Save trace data for tool response using ChatMemoryAdapter - saveTraceData( - 
chatMemoryAdapter, - "ReAct", - lastActionInput.get(), - outputToOutputString(filteredOutput), - sessionId, - traceDisabled, - parentInteractionId, - traceNumber, - lastAction.get() - ); - - StringSubstitutor substitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder), "${parameters.", "}"); - newPrompt.set(substitutor.replace(finalPrompt)); - tmpParameters.put(PROMPT, newPrompt.get()); - if (!interactions.isEmpty()) { - tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); - } - - sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); - streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId); - traceTensors - .add( - ModelTensors - .builder() - .mlModelTensors( - Collections - .singletonList( - ModelTensor.builder().name("response").result(sessionMsgAnswerBuilder.toString()).build() - ) - ) - .build() - ); - - if (finalI == maxIterations - 1) { - handleMaxIterationsReached( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - chatMemoryAdapter, // Use ChatMemoryAdapter - traceNumber, - additionalInfo, - lastThought, - maxIterations, - tools - ); - return; - } - - if (nextStepListener != null) { - ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); - streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); - } - } - }, listener::onFailure); - - if (i == 0) { - ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); - streamingWrapper.executeRequest(request, firstListener); - } - if (nextStepListener != null) { - lastStepListener = nextStepListener; - } - } - } - private void saveMessage( ConversationIndexMemory memory, String question, @@ -1392,171 +933,4 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } - - /** - * Process modern ChatMessage format and build chat history using enhanced templates - */ - private void processModernChatMessages( - List chatMessages, - String sessionId, - ChatMemoryAdapter chatMemoryAdapter, // Add ChatMemoryAdapter parameter - MLAgent mlAgent, - Map params, - Map inputParams, - FunctionCalling functionCalling, - ActionListener listener - ) { - // Use new enhanced template system for ChatMessage - SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); - - // Filter out empty content messages (in-flight requests) - List validMessages = chatMessages - .stream() - .filter(msg -> msg.getContent() != null && !msg.getContent().trim().isEmpty()) - .toList(); - - if (!validMessages.isEmpty()) { - // Build chat history using enhanced template system - String chatHistory = templateEngine.buildSimpleChatHistory(validMessages); - params.put(CHAT_HISTORY, chatHistory); - inputParams.put(CHAT_HISTORY, chatHistory); - } - - // Run agent with modern processing (no Memory object needed) - runAgent(mlAgent, params, listener, sessionId, sessionId, functionCalling); - } - - /** - * Process legacy interactions (ConversationIndex) and build chat history, then run the agent - */ - private void processLegacyInteractions( - List interactions, - String sessionId, - ConversationIndexMemory memory, - MLAgent mlAgent, - Map params, - Map inputParams, - String chatHistoryPrefix, - String chatHistoryQuestionTemplate, - String chatHistoryResponseTemplate, - FunctionCalling functionCalling, - ActionListener listener - ) { - List messageList = new 
ArrayList<>(); - for (Interaction next : interactions) { - String question = next.getInput(); - String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, - // filter out those in-flight requests when run in parallel - if (Strings.isNullOrEmpty(response)) { - continue; - } - messageList - .add( - ConversationIndexMessage - .conversationIndexMessageBuilder() - .sessionId(sessionId) - .question(question) - .response(response) - .build() - ); - } - - if (!messageList.isEmpty()) { - if (chatHistoryQuestionTemplate == null) { - StringBuilder chatHistoryBuilder = new StringBuilder(); - chatHistoryBuilder.append(chatHistoryPrefix); - for (Message message : messageList) { - chatHistoryBuilder.append(message.toString()).append("\n"); - } - params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - } else { - List chatHistory = new ArrayList<>(); - for (Message message : messageList) { - Map messageParams = new HashMap<>(); - messageParams.put("question", processTextDoc(((ConversationIndexMessage) message).getQuestion())); - - StringSubstitutor substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatQuestionMessage = substitutor.replace(chatHistoryQuestionTemplate); - chatHistory.add(chatQuestionMessage); - - messageParams.clear(); - messageParams.put("response", processTextDoc(((ConversationIndexMessage) message).getResponse())); - substitutor = new StringSubstitutor(messageParams, CHAT_HISTORY_MESSAGE_PREFIX, "}"); - String chatResponseMessage = substitutor.replace(chatHistoryResponseTemplate); - chatHistory.add(chatResponseMessage); - } - params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate - inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - } - } - - runAgent(mlAgent, params, listener, memory != null ? 
memory : sessionId, sessionId, functionCalling); - } - - /** - * Create appropriate memory adapter based on memory type - */ - private void createMemoryAdapter(MLAgent mlAgent, Map params, ActionListener listener) { - String memoryType = mlAgent.getMemory().getType(); - MLMemoryType type = MLMemoryType.from(memoryType); - - log.debug("MLChatAgentRunner.createMemoryAdapter: memoryType={}, params={}", memoryType, params.keySet()); - - switch (type) { - case CONVERSATION_INDEX: - // Keep existing flow - no adapter needed (zero risk approach) - ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); - String title = params.get(MLAgentExecutor.QUESTION); - String memoryId = params.get(MLAgentExecutor.MEMORY_ID); - String appType = mlAgent.getAppType(); - - factory.create(title, memoryId, appType, ActionListener.wrap(conversationMemory -> { - // Return ConversationIndexMemory directly - no conversion needed - listener.onResponse(conversationMemory); - }, listener::onFailure)); - break; - - case AGENTIC_MEMORY: - // New agentic memory path - String memoryContainerId = params.get("memory_container_id"); - String sessionId = params.get("session_id"); - String ownerId = params.get("owner_id"); // From user context - - log.debug("AGENTIC_MEMORY path: memoryContainerId={}, sessionId={}, ownerId={}", memoryContainerId, sessionId, ownerId); - - // Validate required parameters - if (memoryContainerId == null) { - log.error("AGENTIC_MEMORY validation failed: memory_container_id is null. Available params: {}", params.keySet()); - listener.onFailure(new IllegalArgumentException("memory_container_id is required for agentic memory")); - return; - } - - // Session management: same pattern as ConversationIndex - if (Strings.isEmpty(sessionId)) { - // CREATE NEW: Generate new session ID if not provided - sessionId = UUID.randomUUID().toString(); - log.debug("Created new agentic memory session: {}", sessionId); - } - // USE EXISTING: If sessionId provided, use it directly - - AgenticMemoryAdapter adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); - log.debug("Created AgenticMemoryAdapter successfully: memoryContainerId={}, sessionId={}", memoryContainerId, sessionId); - listener.onResponse(adapter); - break; - - default: - // Future memory types will be added here: - // - REMOTE_AGENTIC_MEMORY: RemoteAgenticMemoryAdapter (similar format, different location) - // - BEDROCK_AGENTCORE: BedrockAgentCoreMemoryAdapter (format adapted in adapter) - // All future types will use modern ChatMessage pipeline - listener.onFailure(new IllegalArgumentException("Unsupported memory type: " + memoryType)); - } - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java deleted file mode 100644 index 33399208e4..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/SimpleChatHistoryTemplateEngine.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import java.util.List; -import java.util.Map; - -/** - * Simple implementation of ChatHistoryTemplateEngine. - * Provides basic template functionality for ChatMessage formatting. 
- * - * This is a simplified implementation that supports: - * - Role-based message formatting - * - Basic placeholder replacement - * - Content type awareness - * - * Future versions can implement more advanced template engines (Handlebars, etc.) - */ -public class SimpleChatHistoryTemplateEngine implements ChatHistoryTemplateEngine { - - @Override - public String buildChatHistory(List messages, String template, Map context) { - if (messages == null || messages.isEmpty()) { - return ""; - } - - // For now, use a simple approach - build chat history with role-based formatting - StringBuilder chatHistory = new StringBuilder(); - - for (ChatMessage message : messages) { - String formattedMessage = formatMessage(message); - chatHistory.append(formattedMessage).append("\n"); - } - - return chatHistory.toString().trim(); - } - - /** - * Format a single ChatMessage based on its role and content type - */ - private String formatMessage(ChatMessage message) { - String role = message.getRole(); - String content = message.getContent(); - String contentType = message.getContentType(); - - // Role-based formatting - String prefix = switch (role) { - case "user" -> "Human: "; - case "assistant" -> "Assistant: "; - case "system" -> "System: "; - case "tool" -> "Tool Result: "; - default -> role + ": "; - }; - - // Content type awareness - String formattedContent = content; - if ("image".equals(contentType)) { - formattedContent = "[Image: " + content + "]"; - } else if ("tool_result".equals(contentType)) { - Map metadata = message.getMetadata(); - if (metadata != null && metadata.containsKey("tool_name")) { - formattedContent = "Tool " + metadata.get("tool_name") + ": " + content; - } - } else if ("context".equals(contentType)) { - // Context messages (like from long-term memory) get special formatting - formattedContent = "[Context] " + content; - } - - return prefix + formattedContent; - } - - /** - * Build chat history using default simple formatting - */ - public String buildSimpleChatHistory(List messages) { - return buildChatHistory(messages, getDefaultTemplate(), Map.of()); - } -} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMemoryAdapter.java deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java deleted file mode 100644 index 0111642129..0000000000 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ChatMessage.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.memory; - -import java.util.Map; - -/** - * Interface for chat messages in the unified memory system. - * Provides a common abstraction for messages across different memory implementations. 
- */ -public interface ChatMessage { - /** - * Get the role of the message sender - * @return role such as "user", "assistant", "system" - */ - String getRole(); - - /** - * Get the content of the message - * @return message content - */ - String getContent(); - - /** - * Get additional metadata associated with the message - * @return metadata map - */ - Map getMetadata(); -} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java deleted file mode 100644 index f69f7d71e2..0000000000 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgenticMemoryAdapterTest.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; - -import java.util.HashMap; -import java.util.Map; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.core.action.ActionListener; -import org.opensearch.transport.client.Client; - -/** - * Unit tests for AgenticMemoryAdapter. - */ -public class AgenticMemoryAdapterTest { - - @Mock - private Client client; - - private AgenticMemoryAdapter adapter; - private final String memoryContainerId = "test-memory-container"; - private final String sessionId = "test-session"; - private final String ownerId = "test-owner"; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - adapter = new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithNullClient() { - new AgenticMemoryAdapter(null, memoryContainerId, sessionId, ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithNullMemoryContainerId() { - new AgenticMemoryAdapter(client, null, sessionId, ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithEmptyMemoryContainerId() { - new AgenticMemoryAdapter(client, "", sessionId, ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithNullSessionId() { - new AgenticMemoryAdapter(client, memoryContainerId, null, ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithEmptySessionId() { - new AgenticMemoryAdapter(client, memoryContainerId, "", ownerId); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithNullOwnerId() { - new AgenticMemoryAdapter(client, memoryContainerId, sessionId, null); - } - - @Test(expected = IllegalArgumentException.class) - public void testConstructorWithEmptyOwnerId() { - new AgenticMemoryAdapter(client, memoryContainerId, sessionId, ""); - } - - @Test - public void testGetConversationId() { - assertEquals(sessionId, adapter.getConversationId()); - } - - @Test - public void testGetMemoryContainerId() { - assertEquals(memoryContainerId, adapter.getMemoryContainerId()); - } - - @Test(expected = IllegalArgumentException.class) - public void testSaveTraceDataWithNullToolName() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - 
adapter.saveTraceData(null, "input", "output", "parent-id", 1, "action", listener); - } - - @Test(expected = IllegalArgumentException.class) - public void testSaveTraceDataWithEmptyToolName() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - adapter.saveTraceData("", "input", "output", "parent-id", 1, "action", listener); - } - - @Test(expected = IllegalArgumentException.class) - public void testSaveTraceDataWithNullListener() { - adapter.saveTraceData("tool", "input", "output", "parent-id", 1, "action", null); - } - - @Test(expected = IllegalArgumentException.class) - public void testSaveInteractionWithNullListener() { - adapter.saveInteraction("question", "response", null, 1, "action", null); - } - - @Test - public void testUpdateInteractionWithNullInteractionId() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - Map updateFields = new HashMap<>(); - updateFields.put("response", "updated response"); - - adapter.updateInteraction(null, updateFields, listener); - - // Verify that onFailure was called with IllegalArgumentException - verify(listener).onFailure(any(IllegalArgumentException.class)); - } - - @Test - public void testUpdateInteractionWithEmptyInteractionId() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - Map updateFields = new HashMap<>(); - updateFields.put("response", "updated response"); - - adapter.updateInteraction("", updateFields, listener); - - // Verify that onFailure was called with IllegalArgumentException - verify(listener).onFailure(any(IllegalArgumentException.class)); - } - - @Test - public void testUpdateInteractionWithNullUpdateFields() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - adapter.updateInteraction("interaction-id", null, listener); - - // Verify that onFailure was called with IllegalArgumentException - verify(listener).onFailure(any(IllegalArgumentException.class)); - } - - @Test - public void testUpdateInteractionWithEmptyUpdateFields() { - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - Map updateFields = new HashMap<>(); - - adapter.updateInteraction("interaction-id", updateFields, listener); - - // Verify that onFailure was called with IllegalArgumentException - verify(listener).onFailure(any(IllegalArgumentException.class)); - } - - @Test(expected = IllegalArgumentException.class) - public void testUpdateInteractionWithNullListener() { - Map updateFields = new HashMap<>(); - updateFields.put("response", "updated response"); - - adapter.updateInteraction("interaction-id", updateFields, null); - } -} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java deleted file mode 100644 index 990598dd42..0000000000 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/ChatMemoryAdapterTest.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.engine.algorithms.agent; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; - -import java.util.List; - -import org.junit.Test; -import org.opensearch.core.action.ActionListener; - -/** - * Unit tests for ChatMemoryAdapter interface default methods. 
- */ -public class ChatMemoryAdapterTest { - - /** - * Test implementation of ChatMemoryAdapter for testing default methods - */ - private static class TestChatMemoryAdapter implements ChatMemoryAdapter { - @Override - public void getMessages(ActionListener> listener) { - // Test implementation - not used in these tests - } - - @Override - public String getConversationId() { - return "test-conversation-id"; - } - - @Override - public String getMemoryContainerId() { - return "test-memory-container-id"; - } - } - - @Test - public void testSaveInteractionDefaultImplementation() { - TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - // Test that default implementation throws UnsupportedOperationException - adapter.saveInteraction("question", "response", "parentId", 1, "action", listener); - - // Verify that onFailure was called with UnsupportedOperationException - org.mockito.Mockito - .verify(listener) - .onFailure( - org.mockito.ArgumentMatchers - .argThat( - exception -> exception instanceof UnsupportedOperationException - && "Save not implemented".equals(exception.getMessage()) - ) - ); - } - - @Test - public void testUpdateInteractionDefaultImplementation() { - TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - // Test that default implementation throws UnsupportedOperationException - adapter.updateInteraction("interactionId", java.util.Map.of("key", "value"), listener); - - // Verify that onFailure was called with UnsupportedOperationException - org.mockito.Mockito - .verify(listener) - .onFailure( - org.mockito.ArgumentMatchers - .argThat( - exception -> exception instanceof UnsupportedOperationException - && "Update interaction not implemented".equals(exception.getMessage()) - ) - ); - } - - @Test - public void testSaveTraceDataDefaultImplementation() { - TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); - @SuppressWarnings("unchecked") - ActionListener listener = mock(ActionListener.class); - - // Test that default implementation throws UnsupportedOperationException - adapter.saveTraceData("toolName", "input", "output", "parentId", 1, "action", listener); - - // Verify that onFailure was called with UnsupportedOperationException - org.mockito.Mockito - .verify(listener) - .onFailure( - org.mockito.ArgumentMatchers - .argThat( - exception -> exception instanceof UnsupportedOperationException - && "Save trace data not implemented".equals(exception.getMessage()) - ) - ); - } - - @Test - public void testGetConversationId() { - TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); - assertEquals("test-conversation-id", adapter.getConversationId()); - } - - @Test - public void testGetMemoryContainerId() { - TestChatMemoryAdapter adapter = new TestChatMemoryAdapter(); - assertEquals("test-memory-container-id", adapter.getMemoryContainerId()); - } -} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 57c472a4c4..f6c3e3618e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1171,136 +1171,4 @@ public void testConstructLLMParams_DefaultValues() { 
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } - - @Test - public void testCreateMemoryAdapter_ConversationIndex() { - // Test that ConversationIndex memory type returns ConversationIndexMemory - LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); - MLAgent mlAgent = MLAgent - .builder() - .name("test_agent") - .type(MLAgentType.CONVERSATIONAL.name()) - .llm(llmSpec) - .memory(memorySpec) - .build(); - - Map params = new HashMap<>(); - params.put(MLAgentExecutor.QUESTION, "test question"); - params.put(MLAgentExecutor.MEMORY_ID, "test_memory_id"); - - // Mock the memory factory - when(memoryMap.get("conversation_index")).thenReturn(memoryFactory); - - // Create a mock ConversationIndexMemory - org.opensearch.ml.engine.memory.ConversationIndexMemory mockMemory = Mockito - .mock(org.opensearch.ml.engine.memory.ConversationIndexMemory.class); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockMemory); - return null; - }).when(memoryFactory).create(anyString(), anyString(), anyString(), any()); - - // Test the createMemoryAdapter method - ActionListener testListener = new ActionListener() { - @Override - public void onResponse(Object result) { - // Verify that we get back a ConversationIndexMemory - assertTrue("Expected ConversationIndexMemory", result instanceof org.opensearch.ml.engine.memory.ConversationIndexMemory); - assertEquals("Memory should be the mocked instance", mockMemory, result); - } - - @Override - public void onFailure(Exception e) { - Assert.fail("Should not fail: " + e.getMessage()); - } - }; - - // This would normally be a private method call, but for testing we can verify the logic - // by checking that the correct memory type handling works through the public run method - // The actual test would need to be done through integration testing - } - - @Test - public void testCreateMemoryAdapter_AgenticMemory() { - // Test that agentic memory type returns AgenticMemoryAdapter - LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLMemorySpec memorySpec = MLMemorySpec.builder().type("agentic_memory").build(); - MLAgent mlAgent = MLAgent - .builder() - .name("test_agent") - .type(MLAgentType.CONVERSATIONAL.name()) - .llm(llmSpec) - .memory(memorySpec) - .build(); - - Map params = new HashMap<>(); - params.put("memory_container_id", "test_container_id"); - params.put("session_id", "test_session_id"); - params.put("owner_id", "test_owner_id"); - - // This test verifies that the agentic memory path would be taken - // Full integration testing would require mocking the agentic memory services - assertNotNull("MLAgent should be created successfully", mlAgent); - assertEquals("Memory type should be agentic_memory", "agentic_memory", mlAgent.getMemory().getType()); - } - - @Test - public void testEnhancedChatMessage() { - // Test the enhanced ChatMessage format - ChatMessage userMessage = ChatMessage - .builder() - .id("msg_1") - .timestamp(java.time.Instant.now()) - .sessionId("session_123") - .role("user") - .content("Hello, how are you?") - .contentType("text") - .origin("agentic_memory") - .metadata(Map.of("confidence", 0.95)) - .build(); - - ChatMessage assistantMessage = ChatMessage - .builder() - .id("msg_2") - .timestamp(java.time.Instant.now()) - .sessionId("session_123") - .role("assistant") - .content("I'm 
doing well, thank you!") - .contentType("text") - .origin("agentic_memory") - .metadata(Map.of("confidence", 0.98)) - .build(); - - // Verify the enhanced ChatMessage structure - assertEquals("user", userMessage.getRole()); - assertEquals("text", userMessage.getContentType()); - assertEquals("agentic_memory", userMessage.getOrigin()); - assertNotNull(userMessage.getMetadata()); - assertEquals(0.95, userMessage.getMetadata().get("confidence")); - - assertEquals("assistant", assistantMessage.getRole()); - assertEquals("I'm doing well, thank you!", assistantMessage.getContent()); - } - - @Test - public void testSimpleChatHistoryTemplateEngine() { - // Test the new template engine - SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); - - List messages = List - .of( - ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(), - ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build(), - ChatMessage.builder().role("system").content("Weather data retrieved from API").contentType("context").build() - ); - - String chatHistory = templateEngine.buildSimpleChatHistory(messages); - - assertNotNull("Chat history should not be null", chatHistory); - assertTrue("Should contain user message", chatHistory.contains("Human: What's the weather?")); - assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); - assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); - } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index e935ae23ea..feb14a6e3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -708,9 +708,8 @@ public SearchSourceBuilder addUserBackendRolesFilter(User user, SearchSourceBuil public SearchSourceBuilder addOwnerIdFilter(User user, SearchSourceBuilder searchSourceBuilder) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - if (user != null) { - boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); - } + boolQueryBuilder.should(QueryBuilders.termsQuery(OWNER_ID_FIELD, user.getName())); + return applyFilterToSearchSource(searchSourceBuilder, boolQueryBuilder); } From 10d822b79b692b29a80b4bb1738a8596820ada14 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 24 Oct 2025 16:02:10 -0700 Subject: [PATCH 09/58] refactor memory interface Signed-off-by: Yaliang Wu --- .../ml/common/conversation/Interaction.java | 12 +- .../opensearch/ml/common/memory/Memory.java | 57 ++++++ .../opensearch/ml/common}/memory/Message.java | 6 +- .../algorithms/agent/MLAgentExecutor.java | 137 +++++++------- .../algorithms/agent/MLChatAgentRunner.java | 70 ++++---- .../MLConversationalFlowAgentRunner.java | 39 ++-- .../algorithms/agent/MLFlowAgentRunner.java | 39 ++-- .../MLPlanExecuteAndReflectAgentRunner.java | 54 +++--- .../memory/AgenticConversationMemory.java | 169 ++++++++++++++++++ .../ml/engine/memory/AgenticMemoryConfig.java | 45 +++++ .../ml/engine/memory/BaseMessage.java | 2 +- .../memory/ConversationIndexMemory.java | 132 +++++++------- .../algorithms/agent/MLAgentExecutorTest.java | 2 +- .../agent/MLChatAgentRunnerTest.java | 144 ++++++++++++++- .../agent/MLFlowAgentRunnerTest.java | 56 +++--- 
...LPlanExecuteAndReflectAgentRunnerTest.java | 6 +- .../memory/AgenticConversationMemoryTest.java | 156 ++++++++++++++++ .../memory/ConversationIndexMemoryTest.java | 169 ++++++++---------- .../ml/plugin/MachineLearningPlugin.java | 7 +- .../ml/common/spi/memory/Memory.java | 64 ------- 20 files changed, 925 insertions(+), 441 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/memory/Memory.java rename {spi/src/main/java/org/opensearch/ml/common/spi => common/src/main/java/org/opensearch/ml/common}/memory/Message.java (74%) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java delete mode 100644 spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 5da68b0d07..19c6ee21df 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -28,6 +28,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.memory.Message; import org.opensearch.search.SearchHit; import lombok.AllArgsConstructor; @@ -39,7 +40,7 @@ */ @Builder @AllArgsConstructor -public class Interaction implements Writeable, ToXContentObject { +public class Interaction implements Writeable, ToXContentObject, Message { @Getter private String id; @@ -275,4 +276,13 @@ public String toString() { + "}"; } + @Override + public String getType() { + return ""; + } + + @Override + public String getContent() { + return ""; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/memory/Memory.java b/common/src/main/java/org/opensearch/ml/common/memory/Memory.java new file mode 100644 index 0000000000..9cd18deeae --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/memory/Memory.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.memory; + +import java.util.List; +import java.util.Map; + +import org.opensearch.core.action.ActionListener; + +/** + * A general memory interface. + * @param Message type + * @param Save response type + * @param Update response type + */ +public interface Memory { + + /** + * Get memory type. + * @return memory type + */ + String getType(); + + /** + * Get memory ID. + * @return memory ID + */ + String getId(); + + default void save(Message message, String parentId, Integer traceNum, String action) {} + + default void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) {} + + default void update(String messageId, Map updateContent, ActionListener updateListener) {} + + default void getMessages(int size, ActionListener> listener) {} + + /** + * Clear all memory. + */ + void clear(); + + void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener wrap); + + interface Factory { + /** + * Create an instance of this Memory. 
+ * + * @param params Parameters for the memory + * @param listener Action listener for the memory creation action + */ + void create(Map params, ActionListener listener); + } +} diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java b/common/src/main/java/org/opensearch/ml/common/memory/Message.java similarity index 74% rename from spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java rename to common/src/main/java/org/opensearch/ml/common/memory/Message.java index 148cc769e3..d7ca18718c 100644 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java +++ b/common/src/main/java/org/opensearch/ml/common/memory/Message.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.spi.memory; +package org.opensearch.ml.common.memory; /** * General message interface. @@ -12,13 +12,13 @@ public interface Message { /** * Get message type. - * @return + * @return message type */ String getType(); /** * Get message content. - * @return + * @return message content */ String getContent(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..9c0ad8a3b2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -54,6 +55,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.model.ModelTensor; @@ -61,7 +63,6 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.settings.SettingsChangeListener; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.Executable; import org.opensearch.ml.engine.annotation.Function; @@ -245,55 +246,58 @@ public void execute(Input input, ActionListener listener, TransportChann && memorySpec.getType() != null && memoryFactoryMap.containsKey(memorySpec.getType()) && (memoryId == null || parentInteractionId == null)) { - ConversationIndexMemory.Factory conversationIndexMemoryFactory = - (ConversationIndexMemory.Factory) memoryFactoryMap.get(memorySpec.getType()); - conversationIndexMemoryFactory - .create(question, memoryId, appType, ActionListener.wrap(memory -> { - inputDataSet.getParameters().put(MEMORY_ID, memory.getConversationId()); - // get question for regenerate - if (regenerateInteractionId != null) { - log.info("Regenerate for existing interaction {}", regenerateInteractionId); - client - .execute( - 
GetInteractionAction.INSTANCE, - new GetInteractionRequest(regenerateInteractionId), - ActionListener.wrap(interactionRes -> { - inputDataSet - .getParameters() - .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel - ); - }, e -> { - log.error("Failed to get existing interaction for regeneration", e); - listener.onFailure(e); - }) - ); - } else { - saveRootInteractionAndExecute( - listener, - memory, - inputDataSet, - mlTask, - isAsync, - outputs, - modelTensors, - mlAgent, - channel + Memory.Factory> memoryFactory = memoryFactoryMap.get(memorySpec.getType()); + + Map params = new HashMap<>(); + params.put(ConversationIndexMemory.MEMORY_NAME, question); + params.put(ConversationIndexMemory.MEMORY_ID, memoryId); + params.put(APP_TYPE, appType); + memoryFactory.create(params, ActionListener.wrap(memory -> { + inputDataSet.getParameters().put(MEMORY_ID, memory.getId()); + // get question for regenerate + if (regenerateInteractionId != null) { + log.info("Regenerate for existing interaction {}", regenerateInteractionId); + client + .execute( + GetInteractionAction.INSTANCE, + new GetInteractionRequest(regenerateInteractionId), + ActionListener.wrap(interactionRes -> { + inputDataSet + .getParameters() + .putIfAbsent(QUESTION, interactionRes.getInteraction().getInput()); + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + }, e -> { + log.error("Failed to get existing interaction for regeneration", e); + listener.onFailure(e); + }) ); - } - }, ex -> { - log.error("Failed to read conversation memory", ex); - listener.onFailure(ex); - })); + } else { + saveRootInteractionAndExecute( + listener, + memory, + inputDataSet, + mlTask, + isAsync, + outputs, + modelTensors, + mlAgent, + channel + ); + } + }, ex -> { + log.error("Failed to read conversation memory", ex); + listener.onFailure(ex); + })); } else { // For existing conversations, create memory instance using factory if (memorySpec != null && memorySpec.getType() != null) { @@ -304,9 +308,7 @@ public void execute(Input input, ActionListener listener, TransportChann // be null factory .create( - null, - memoryId, - appType, + Map.of(MEMORY_ID, memoryId, APP_TYPE, appType), ActionListener .wrap( createdMemory -> executeAgent( @@ -377,7 +379,7 @@ public void execute(Input input, ActionListener listener, TransportChann */ private void saveRootInteractionAndExecute( ActionListener listener, - ConversationIndexMemory memory, + Memory memory, RemoteInferenceInputDataSet inputDataSet, MLTask mlTask, boolean isAsync, @@ -396,7 +398,7 @@ private void saveRootInteractionAndExecute( .question(question) .response("") .finalAnswer(true) - .sessionId(memory.getConversationId()) + .sessionId(memory.getId()) .build(); memory.save(msg, null, null, null, ActionListener.wrap(interaction -> { log.info("Created parent interaction ID: {}", interaction.getId()); @@ -404,7 +406,6 @@ private void saveRootInteractionAndExecute( // only delete previous interaction when new interaction created if (regenerateInteractionId != null) { memory - .getMemoryManager() .deleteInteractionAndTrace( regenerateInteractionId, ActionListener @@ -413,7 +414,7 @@ private void saveRootInteractionAndExecute( inputDataSet, mlTask, isAsync, - memory.getConversationId(), + memory.getId(), mlAgent, outputs, modelTensors, @@ -428,18 +429,7 @@ 
private void saveRootInteractionAndExecute( ) ); } else { - executeAgent( - inputDataSet, - mlTask, - isAsync, - memory.getConversationId(), - mlAgent, - outputs, - modelTensors, - listener, - memory, - channel - ); + executeAgent(inputDataSet, mlTask, isAsync, memory.getId(), mlAgent, outputs, modelTensors, listener, memory, channel); } }, ex -> { log.error("Failed to create parent interaction", ex); @@ -456,7 +446,7 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, TransportChannel channel ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; @@ -535,7 +525,7 @@ private ActionListener createAgentActionListener( List modelTensors, String agentType, String parentInteractionId, - ConversationIndexMemory memory + Memory memory ) { return ActionListener.wrap(output -> { if (output != null) { @@ -556,7 +546,7 @@ private ActionListener createAsyncTaskUpdater( List outputs, List modelTensors, String parentInteractionId, - ConversationIndexMemory memory + Memory memory ) { String taskId = mlTask.getTaskId(); Map agentResponse = new HashMap<>(); @@ -711,15 +701,14 @@ public void indexMLTask(MLTask mlTask, ActionListener listener) { } } - private void updateInteractionWithFailure(String interactionId, ConversationIndexMemory memory, String errorMessage) { + private void updateInteractionWithFailure(String interactionId, Memory memory, String errorMessage) { if (interactionId != null && memory != null) { String failureMessage = "Agent execution failed: " + errorMessage; Map updateContent = new HashMap<>(); updateContent.put(RESPONSE_FIELD, failureMessage); memory - .getMemoryManager() - .updateInteraction( + .update( interactionId, updateContent, ActionListener diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 7e1a4050bd..dad28a44f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; +import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -34,6 +35,8 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import java.security.PrivilegedActionException; @@ -61,11 +64,11 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; +import 
org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; @@ -73,7 +76,6 @@ import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; import org.opensearch.ml.engine.function_calling.LLMMessage; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; @@ -186,10 +188,11 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap.get(memoryType); - conversationIndexMemoryFactory.create(title, memoryId, appType, ActionListener.wrap(memory -> { + Memory.Factory> factory = memoryFactoryMap.get(memoryType); + + factory.create(Map.of(MEMORY_ID, memoryId, MEMORY_NAME, title, APP_TYPE, appType), ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 - memory.getMessages(ActionListener.>wrap(r -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { String question = next.getInput(); @@ -203,7 +206,7 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener .add( ConversationIndexMessage .conversationIndexMessageBuilder() - .sessionId(memory.getConversationId()) + .sessionId(memory.getId()) .question(question) .response(response) .build() @@ -244,11 +247,11 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener } } - runAgent(mlAgent, params, listener, memory, memory.getConversationId(), functionCalling); + runAgent(mlAgent, params, listener, memory, memory.getId(), functionCalling); }, e -> { log.error("Failed to get chat history", e); listener.onFailure(e); - }), messageHistoryLimit); + })); }, listener::onFailure)); } @@ -302,8 +305,6 @@ private void runReAct( boolean traceDisabled = tmpParameters.containsKey(DISABLE_TRACE) && Boolean.parseBoolean(tmpParameters.get(DISABLE_TRACE)); // Create root interaction. 
- ConversationIndexMemory conversationIndexMemory = (ConversationIndexMemory) memory; - // Trace number AtomicInteger traceNumber = new AtomicInteger(0); @@ -364,7 +365,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, finalAnswer @@ -388,7 +389,7 @@ private void runReAct( ); saveTraceData( - conversationIndexMemory, + memory, memory.getType(), question, thoughtResponse, @@ -408,7 +409,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, lastThought, @@ -467,7 +468,7 @@ private void runReAct( scratchpadBuilder.append(toolResponse).append("\n\n"); saveTraceData( - conversationIndexMemory, + memory, "ReAct", lastActionInput.get(), outputToOutputString(filteredOutput), @@ -509,7 +510,7 @@ private void runReAct( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, lastThought, @@ -680,8 +681,8 @@ private static void updateParametersAcrossTools(Map tmpParameter } public static void saveTraceData( - ConversationIndexMemory conversationIndexMemory, - String memory, + Memory memory, + String memoryType, String question, String thoughtResponse, String sessionId, @@ -690,17 +691,17 @@ public static void saveTraceData( AtomicInteger traceNumber, String origin ) { - if (conversationIndexMemory != null) { + if (memory != null) { ConversationIndexMessage msgTemp = ConversationIndexMessage .conversationIndexMessageBuilder() - .type(memory) + .type(memoryType) .question(question) .response(thoughtResponse) .finalAnswer(false) .sessionId(sessionId) .build(); if (!traceDisabled) { - conversationIndexMemory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), origin); + memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), origin); } } } @@ -713,7 +714,7 @@ private void sendFinalAnswer( boolean verbose, boolean traceDisabled, List cotModelTensors, - ConversationIndexMemory conversationIndexMemory, + Memory memory, AtomicInteger traceNumber, Map additionalInfo, String finalAnswer @@ -721,12 +722,11 @@ private void sendFinalAnswer( // Send completion chunk for streaming streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); - if (conversationIndexMemory != null) { + if (memory != null) { String copyOfFinalAnswer = finalAnswer; ActionListener saveTraceListener = ActionListener.wrap(r -> { - conversationIndexMemory - .getMemoryManager() - .updateInteraction( + memory + .update( parentInteractionId, Map.of(AI_RESPONSE_FIELD, copyOfFinalAnswer, ADDITIONAL_INFO_FIELD, additionalInfo), ActionListener.wrap(res -> { @@ -742,17 +742,7 @@ private void sendFinalAnswer( }, e -> { listener.onFailure(e); }) ); }, e -> { listener.onFailure(e); }); - saveMessage( - conversationIndexMemory, - question, - finalAnswer, - sessionId, - parentInteractionId, - traceNumber, - true, - traceDisabled, - saveTraceListener - ); + saveMessage(memory, question, finalAnswer, sessionId, parentInteractionId, traceNumber, true, traceDisabled, saveTraceListener); } else { streamingWrapper .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); @@ -882,7 +872,7 @@ private void handleMaxIterationsReached( boolean verbose, boolean traceDisabled, List traceTensors, - ConversationIndexMemory conversationIndexMemory, + Memory memory, AtomicInteger traceNumber, Map additionalInfo, AtomicReference lastThought, @@ -900,7 +890,7 @@ private void 
handleMaxIterationsReached( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, incompleteResponse @@ -909,7 +899,7 @@ private void handleMaxIterationsReached( } private void saveMessage( - ConversationIndexMemory memory, + Memory memory, String question, String finalAnswer, String sessionId, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 54d847b929..86b46bbf56 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -19,9 +19,12 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.security.PrivilegedActionException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -39,15 +42,14 @@ import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.remote.metadata.client.SdkClient; @@ -109,9 +111,15 @@ public void run(MLAgent mlAgent, Map params, ActionListener { - memory.getMessages(ActionListener.>wrap(r -> { + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + + Map createMemoryParams = new HashMap<>(); + params.put(MEMORY_NAME, title); + params.put(MEMORY_ID, memoryId); + params.put(APP_TYPE, appType); + + memoryFactory.create(createMemoryParams, ActionListener.wrap(memory -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { String question = next.getInput(); @@ -125,7 +133,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener params, ActionListener { log.error("Failed to get chat history", e); listener.onFailure(e); - }), messageHistoryLimit); + })); }, listener::onFailure)); } @@ -153,7 +161,7 @@ private void runAgent( MLAgent mlAgent, Map params, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, String memoryId, String parentInteractionId ) { @@ -244,7 +252,7 @@ private void runAgent( private void 
processOutput( Map params, ActionListener listener, - ConversationIndexMemory memory, + Memory memory, String memoryId, String parentInteractionId, List toolSpecs, @@ -357,7 +365,7 @@ private void runNextStep( private void saveMessage( Map params, - ConversationIndexMemory memory, + Memory memory, String outputResponse, String memoryId, String parentInteractionId, @@ -392,11 +400,10 @@ void updateMemoryWithListener( if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { return; } - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - conversationIndexMemoryFactory + Memory.Factory factory = memoryFactoryMap.get(memorySpec.getType()); + factory .create( - memoryId, + Map.of(MEMORY_ID, memoryId), ActionListener .wrap( memory -> memory.update(interactionId, Map.of(ActionConstants.ADDITIONAL_INFO_FIELD, additionalInfo), listener), diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 30725a8c47..7c8742a570 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; import java.util.ArrayList; import java.util.List; @@ -28,9 +29,9 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.common.utils.ToolUtils; @@ -169,23 +170,23 @@ public void run(MLAgent mlAgent, Map params, ActionListener additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) { - if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { - return; - } - ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap - .get(memorySpec.getType()); - conversationIndexMemoryFactory - .create( - memoryId, - ActionListener - .wrap( - memory -> updateInteraction(additionalInfo, interactionId, memory), - e -> log.error("Failed create memory from id: {}", memoryId, e) - ) - ); - } + // @VisibleForTesting + // void updateMemory(Map additionalInfo, MLMemorySpec memorySpec, String memoryId, String interactionId) { + // if (memoryId == null || interactionId == null || memorySpec == null || memorySpec.getType() == null) { + // return; + // } + // ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) memoryFactoryMap + // .get(memorySpec.getType()); + // conversationIndexMemoryFactory + // .create( + // memoryId, + // ActionListener + // .wrap( + // memory -> 
updateInteraction(additionalInfo, interactionId, memory), + // e -> log.error("Failed create memory from id: {}", memoryId, e) + // ) + // ); + // } @VisibleForTesting void updateMemoryWithListener( @@ -202,7 +203,7 @@ void updateMemoryWithListener( .get(memorySpec.getType()); conversationIndexMemoryFactory .create( - memoryId, + Map.of(MEMORY_ID, memoryId), ActionListener .wrap( memory -> updateInteractionWithListener(additionalInfo, interactionId, memory, listener), diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index b8b89d8aa2..96d314503c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -32,6 +32,9 @@ import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.util.ArrayList; import java.util.HashMap; @@ -57,10 +60,10 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -289,32 +292,35 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { - memory.getMessages(ActionListener.>wrap(interactions -> { - List completedSteps = new ArrayList<>(); - for (Interaction interaction : interactions) { - String question = interaction.getInput(); - String response = interaction.getResponse(); - - if (Strings.isNullOrEmpty(response)) { - continue; + .create( + Map.of(MEMORY_ID, memoryId, MEMORY_NAME, apiParams.get(USER_PROMPT_FIELD), APP_TYPE, appType), + ActionListener.wrap(memory -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { + List completedSteps = new ArrayList<>(); + for (Interaction interaction : interactions) { + String question = interaction.getInput(); + String response = interaction.getResponse(); + + if (Strings.isNullOrEmpty(response)) { + continue; + } + + completedSteps.add(question); + completedSteps.add(response); } - completedSteps.add(question); - completedSteps.add(response); - } - - if (!completedSteps.isEmpty()) { - addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); - usePlannerWithHistoryPromptTemplate(allParams); - } + if (!completedSteps.isEmpty()) { + addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); + 
usePlannerWithHistoryPromptTemplate(allParams); + } - setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - }), messageHistoryLimit); - }, listener::onFailure)); + setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + })); + }, listener::onFailure) + ); } private void setToolsAndRunAgent( diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java new file mode 100644 index 0000000000..7cd2484078 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -0,0 +1,169 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; + +import java.util.List; +import java.util.Map; + +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; +import org.opensearch.ml.common.transport.session.MLCreateSessionAction; +import org.opensearch.ml.common.transport.session.MLCreateSessionInput; +import org.opensearch.ml.common.transport.session.MLCreateSessionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.transport.client.Client; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +/** + * Agentic Memory implementation that stores conversations in Memory Container + * Uses TransportCreateSessionAction and TransportAddMemoriesAction for all operations + */ +@Log4j2 +@Getter +public class AgenticConversationMemory implements Memory { + + public static final String TYPE = "agentic_memory"; + + public AgenticConversationMemory(Client client, String memoryId, String memoryContainerId) {} + + @Override + public String getType() { + return ""; + } + + @Override + public String getId() { + return ""; + } + + @Override + public void save(Message message, String parentId, Integer traceNum, String action) { + Memory.super.save(message, parentId, traceNum, action); + } + + @Override + public void save( + Message message, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + Memory.super.save(message, parentId, traceNum, action, listener); + } + + @Override + public void update(String messageId, Map updateContent, ActionListener updateListener) { + Memory.super.update(messageId, updateContent, updateListener); + } + + @Override + public void getMessages(int size, ActionListener> listener) { + Memory.super.getMessages(size, listener); + } + + @Override + public void clear() { + + } + + @Override + public void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener wrap) { + + } + + /** + * Factory for creating AgenticConversationMemory instances + */ + public static class Factory implements Memory.Factory { + private Client client; + + public 
void init(Client client) { + this.client = client; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map == null || map.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating AgenticConversationMemory")); + return; + } + + String memoryId = (String) map.get(MEMORY_ID); + String name = (String) map.get(MEMORY_NAME); + String appType = (String) map.get(APP_TYPE); + String memoryContainerId = (String) map.get("memory_container_id"); + + create(name, memoryId, appType, memoryContainerId, listener); + } + + public void create( + String name, + String memoryId, + String appType, + String memoryContainerId, + ActionListener listener + ) { + // Memory container ID is required for AgenticConversationMemory + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalArgumentException( + "Memory container ID is required for AgenticConversationMemory. " + + "Please provide 'memory_container_id' in the agent configuration." + ) + ); + return; + } + + if (Strings.isEmpty(memoryId)) { + // Create new session using TransportCreateSessionAction + createSessionInMemoryContainer(name, memoryContainerId, ActionListener.wrap(sessionId -> { + create(sessionId, memoryContainerId, listener); + log.debug("Created session in memory container, session id: {}", sessionId); + }, e -> { + log.error("Failed to create session in memory container", e); + listener.onFailure(e); + })); + } else { + // Use existing session/memory ID + create(memoryId, memoryContainerId, listener); + } + } + + /** + * Create a new session in the memory container using the new session API + */ + private void createSessionInMemoryContainer(String summary, String memoryContainerId, ActionListener listener) { + MLCreateSessionInput input = MLCreateSessionInput.builder().memoryContainerId(memoryContainerId).summary(summary).build(); + + MLCreateSessionRequest request = MLCreateSessionRequest.builder().mlCreateSessionInput(input).build(); + + client + .execute( + MLCreateSessionAction.INSTANCE, + request, + ActionListener.wrap(response -> { listener.onResponse(response.getSessionId()); }, e -> { + log.error("Failed to create session via TransportCreateSessionAction", e); + listener.onFailure(e); + }) + ); + } + + public void create(String memoryId, String memoryContainerId, ActionListener listener) { + listener.onResponse(new AgenticConversationMemory(client, memoryId, memoryContainerId)); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java new file mode 100644 index 0000000000..987d947567 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticMemoryConfig.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import lombok.Builder; +import lombok.Data; + +/** + * Configuration for Agentic Memory integration + */ +@Data +@Builder +public class AgenticMemoryConfig { + + /** + * Memory container ID to use for storing conversations + */ + private String memoryContainerId; + + /** + * Whether to enable memory container integration + * If false, falls back to ConversationIndexMemory behavior + */ + @Builder.Default + private boolean enabled = true; + + /** + * Whether to enable inference (long-term memory extraction) + */ + @Builder.Default + private boolean enableInference = 
true; + + /** + * Custom namespace fields to add to memory container entries + */ + private java.util.Map customNamespace; + + /** + * Custom tags to add to memory container entries + */ + private java.util.Map customTags; +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java index 05b3185a34..562e425375 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/BaseMessage.java @@ -9,7 +9,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.common.memory.Message; import lombok.Builder; import lombok.Getter; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java index 9720661eeb..e8e5f87a5a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -5,30 +5,19 @@ package org.opensearch.ml.engine.memory; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; import java.util.Map; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.search.SearchRequest; import org.opensearch.action.update.UpdateResponse; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; -import org.opensearch.ml.common.spi.memory.Memory; -import org.opensearch.ml.common.spi.memory.Message; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortOrder; import org.opensearch.transport.client.Client; import lombok.Getter; @@ -36,7 +25,7 @@ @Log4j2 @Getter -public class ConversationIndexMemory implements Memory { +public class ConversationIndexMemory implements Memory { public static final String TYPE = "conversation_index"; public static final String CONVERSATION_ID = "conversation_id"; public static final String FINAL_ANSWER = "final_answer"; @@ -75,28 +64,34 @@ public String getType() { } @Override - public void save(String id, Message message) { - this.save(id, message, ActionListener.wrap(r -> { log.info("saved message into {} memory, session id: {}", TYPE, id); }, e -> { - log.error("Failed to save message to memory", e); - })); + public String getId() { + return this.conversationId; } - @Override - public void save(String id, Message message, ActionListener listener) { - 
mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { - if (created) { - IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName).setRefreshPolicy(IMMEDIATE); - ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); - indexRequest.source(builder); - client.index(indexRequest, listener); - } else { - listener.onFailure(new RuntimeException("Failed to create memory message index")); - } - }, e -> { listener.onFailure(new RuntimeException("Failed to create memory message index", e)); })); - } + // @Override + // public void save(String id, Message message) { + // this.save(id, message, ActionListener.wrap(r -> { log.info("saved message into {} memory, session id: {}", TYPE, id); }, e -> { + // log.error("Failed to save message to memory", e); + // })); + // } + + // @Override + // public void save(String id, Message message, ActionListener listener) { + // mlIndicesHandler.initMemoryMessageIndex(ActionListener.wrap(created -> { + // if (created) { + // IndexRequest indexRequest = new IndexRequest(memoryMessageIndexName).setRefreshPolicy(IMMEDIATE); + // ConversationIndexMessage conversationIndexMessage = (ConversationIndexMessage) message; + // XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + // conversationIndexMessage.toXContent(builder, ToXContent.EMPTY_PARAMS); + // indexRequest.source(builder); + // client.index(indexRequest, listener); + // } else { + // listener.onFailure(new RuntimeException("Failed to create memory message index")); + // } + // }, e -> { listener.onFailure(new RuntimeException("Failed to create memory message index", e)); })); + // } + @Override public void save(Message message, String parentId, Integer traceNum, String action) { this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { log @@ -110,39 +105,43 @@ public void save(Message message, String parentId, Integer traceNum, String acti }, e -> { log.error("Failed to save interaction", e); })); } - public void save(Message message, String parentId, Integer traceNum, String action, ActionListener listener) { + @Override + public void save( + Message message, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { ConversationIndexMessage msg = (ConversationIndexMessage) message; memoryManager .createInteraction(conversationId, msg.getQuestion(), null, msg.getResponse(), action, null, parentId, traceNum, listener); } - @Override - public void getMessages(String id, ActionListener listener) { - SearchRequest searchRequest = new SearchRequest(); - searchRequest.indices(memoryMessageIndexName); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.size(10000); - QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(CONVERSATION_ID, id); - - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - boolQueryBuilder.must(sessionIdQueryBuilder); - - if (retrieveFinalAnswer) { - QueryBuilder finalAnswerQueryBuilder = new TermQueryBuilder(FINAL_ANSWER, true); - boolQueryBuilder.must(finalAnswerQueryBuilder); - } + // @Override + // public void getMessages(String id, ActionListener listener) { + // SearchRequest searchRequest = new SearchRequest(); + // searchRequest.indices(memoryMessageIndexName); + // SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + // 
sourceBuilder.size(10000); + // QueryBuilder sessionIdQueryBuilder = new TermQueryBuilder(CONVERSATION_ID, id); + // + // BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + // boolQueryBuilder.must(sessionIdQueryBuilder); + // + // if (retrieveFinalAnswer) { + // QueryBuilder finalAnswerQueryBuilder = new TermQueryBuilder(FINAL_ANSWER, true); + // boolQueryBuilder.must(finalAnswerQueryBuilder); + // } + // + // sourceBuilder.query(boolQueryBuilder); + // sourceBuilder.sort(CREATED_TIME, SortOrder.ASC); + // searchRequest.source(sourceBuilder); + // client.search(searchRequest, listener); + // } - sourceBuilder.query(boolQueryBuilder); - sourceBuilder.sort(CREATED_TIME, SortOrder.ASC); - searchRequest.source(sourceBuilder); - client.search(searchRequest, listener); - } - - public void getMessages(ActionListener listener) { - memoryManager.getFinalInteractions(conversationId, LAST_N_INTERACTIONS, listener); - } - - public void getMessages(ActionListener listener, int size) { + @Override + public void getMessages(int size, ActionListener listener) { memoryManager.getFinalInteractions(conversationId, size, listener); } @@ -152,14 +151,15 @@ public void clear() { } @Override - public void remove(String id) { - throw new RuntimeException("remove method is not supported in ConversationIndexMemory"); - } - public void update(String messageId, Map updateContent, ActionListener updateListener) { getMemoryManager().updateInteraction(messageId, updateContent, updateListener); } + @Override + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + memoryManager.deleteInteractionAndTrace(interactionId, listener); + } + public static class Factory implements Memory.Factory { private Client client; private MLIndicesHandler mlIndicesHandler; @@ -186,7 +186,7 @@ public void create(Map map, ActionListener listener) { + private void create(String name, String memoryId, String appType, ActionListener listener) { if (Strings.isEmpty(memoryId)) { memoryManager.createConversation(name, appType, ActionListener.wrap(r -> { create(r.getId(), listener); @@ -200,7 +200,7 @@ public void create(String name, String memoryId, String appType, ActionListener< } } - public void create(String memoryId, ActionListener listener) { + private void create(String memoryId, ActionListener listener) { listener .onResponse( new ConversationIndexMemory( diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index efc84d8f8c..344e28e487 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -71,13 +71,13 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import 
org.opensearch.ml.engine.memory.ConversationIndexMemory;
 import org.opensearch.ml.engine.memory.MLMemoryManager;
diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
index f6c3e3618e..115cc8bf03 100644
--- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
+++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java
@@ -50,10 +50,10 @@
 import org.opensearch.ml.common.agent.MLMemorySpec;
 import org.opensearch.ml.common.agent.MLToolSpec;
 import org.opensearch.ml.common.conversation.Interaction;
+import org.opensearch.ml.common.memory.Memory;
 import org.opensearch.ml.common.output.model.ModelTensor;
 import org.opensearch.ml.common.output.model.ModelTensorOutput;
 import org.opensearch.ml.common.output.model.ModelTensors;
-import org.opensearch.ml.common.spi.memory.Memory;
 import org.opensearch.ml.common.spi.tools.Tool;
 import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.ml.engine.memory.ConversationIndexMemory;
@@ -140,7 +140,7 @@ public void setup() {
-            ActionListener> listener = invocation.getArgument(0);
+            ActionListener> listener = invocation.getArgument(1);
             listener.onResponse(generateInteractions(2));
             return null;
-        }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());
+        }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture());
         when(conversationIndexMemory.getConversationId()).thenReturn("conversation_id");
         when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager);
         doAnswer(invocation -> {
@@ -477,7 +477,7 @@ public void testChatHistoryExcludeOngoingQuestion() {
            interactionList.add(inProgressInteraction);
            listener.onResponse(interactionList);
            return null;
-        }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());
+        }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture());
 
         HashMap params = new HashMap<>();
         params.put(MESSAGE_HISTORY_LIMIT, "5");
@@ -533,7 +533,7 @@ private void testInteractions(String maxInteraction) {
            interactionList.add(inProgressInteraction);
            listener.onResponse(interactionList);
            return null;
-        }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());
+        }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture());
 
         HashMap params = new HashMap<>();
         params.put("verbose", "true");
@@ -563,7 +563,7 @@ public void testChatHistoryException() {
-            ActionListener> listener = invocation.getArgument(0);
+            ActionListener> listener = invocation.getArgument(1);
             listener.onFailure(new RuntimeException("Test Exception"));
             return null;
-        }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture());
+        }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture());
 
         HashMap params = new HashMap<>();
         mlChatAgentRunner.run(mlAgent, params, agentActionListener, null);
@@ -904,7 +904,7 @@ public void testToolExecutionWithChatHistoryParameter() {
            interactionList.add(inProgressInteraction);
            listener.onResponse(interactionList);
            return null;
-        }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), 
messageHistoryLimitCapture.capture()); + }).when(conversationIndexMemory).getMessages(messageHistoryLimitCapture.capture(), memoryInteractionCapture.capture()); doAnswer(generateToolResponse("First tool response")) .when(firstTool) @@ -1171,4 +1171,136 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION)); Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + + @Test + public void testCreateMemoryAdapter_ConversationIndex() { + // Test that ConversationIndex memory type returns ConversationIndexMemory + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("conversation_index").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.QUESTION, "test question"); + params.put(MLAgentExecutor.MEMORY_ID, "test_memory_id"); + + // Mock the memory factory + when(memoryMap.get("conversation_index")).thenReturn(memoryFactory); + + // Create a mock ConversationIndexMemory + org.opensearch.ml.engine.memory.ConversationIndexMemory mockMemory = Mockito + .mock(org.opensearch.ml.engine.memory.ConversationIndexMemory.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(mockMemory); + return null; + }).when(memoryFactory).create(anyString(), anyString(), anyString(), any()); + + // Test the createMemoryAdapter method + ActionListener testListener = new ActionListener() { + @Override + public void onResponse(Object result) { + // Verify that we get back a ConversationIndexMemory + assertTrue("Expected ConversationIndexMemory", result instanceof org.opensearch.ml.engine.memory.ConversationIndexMemory); + assertEquals("Memory should be the mocked instance", mockMemory, result); + } + + @Override + public void onFailure(Exception e) { + Assert.fail("Should not fail: " + e.getMessage()); + } + }; + + // This would normally be a private method call, but for testing we can verify the logic + // by checking that the correct memory type handling works through the public run method + // The actual test would need to be done through integration testing + } + + @Test + public void testCreateMemoryAdapter_AgenticMemory() { + // Test that agentic memory type returns AgenticMemoryAdapter + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + MLMemorySpec memorySpec = MLMemorySpec.builder().type("agentic_memory").build(); + MLAgent mlAgent = MLAgent + .builder() + .name("test_agent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(memorySpec) + .build(); + + Map params = new HashMap<>(); + params.put("memory_container_id", "test_container_id"); + params.put("session_id", "test_session_id"); + params.put("owner_id", "test_owner_id"); + + // This test verifies that the agentic memory path would be taken + // Full integration testing would require mocking the agentic memory services + assertNotNull("MLAgent should be created successfully", mlAgent); + assertEquals("Memory type should be agentic_memory", "agentic_memory", mlAgent.getMemory().getType()); + } + + @Test + public void testEnhancedChatMessage() { + // Test the enhanced ChatMessage format + ChatMessage userMessage = ChatMessage + .builder() + .id("msg_1") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("user") + 
.content("Hello, how are you?") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.95)) + .build(); + + ChatMessage assistantMessage = ChatMessage + .builder() + .id("msg_2") + .timestamp(java.time.Instant.now()) + .sessionId("session_123") + .role("assistant") + .content("I'm doing well, thank you!") + .contentType("text") + .origin("agentic_memory") + .metadata(Map.of("confidence", 0.98)) + .build(); + + // Verify the enhanced ChatMessage structure + assertEquals("user", userMessage.getRole()); + assertEquals("text", userMessage.getContentType()); + assertEquals("agentic_memory", userMessage.getOrigin()); + assertNotNull(userMessage.getMetadata()); + assertEquals(0.95, userMessage.getMetadata().get("confidence")); + + assertEquals("assistant", assistantMessage.getRole()); + assertEquals("I'm doing well, thank you!", assistantMessage.getContent()); + } + + @Test + public void testSimpleChatHistoryTemplateEngine() { + // Test the new template engine + SimpleChatHistoryTemplateEngine templateEngine = new SimpleChatHistoryTemplateEngine(); + + List messages = List + .of( + ChatMessage.builder().role("user").content("What's the weather?").contentType("text").build(), + ChatMessage.builder().role("assistant").content("It's sunny today!").contentType("text").build(), + ChatMessage.builder().role("system").content("Weather data retrieved from API").contentType("context").build() + ); + + String chatHistory = templateEngine.buildSimpleChatHistory(messages); + + assertNotNull("Chat history should not be null", chatHistory); + assertTrue("Should contain user message", chatHistory.contains("Human: What's the weather?")); + assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); + assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java index cecb99f32e..a9d3ba6fcd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunnerTest.java @@ -12,16 +12,12 @@ import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.utils.ToolUtils.buildToolParameters; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.io.IOException; import java.util.Arrays; @@ -50,10 +46,10 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; 
-import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.utils.ToolUtils; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -423,31 +419,31 @@ public void testWithMemoryNotSet() { assertEquals(SECOND_TOOL_RESPONSE, agentOutput.get(0).getResult()); } - @Test - public void testUpdateMemory() { - // Mocking MLMemorySpec - MLMemorySpec memorySpec = mock(MLMemorySpec.class); - when(memorySpec.getType()).thenReturn("memoryType"); - - // Mocking Memory Factory and Memory - - ConversationIndexMemory.Factory memoryFactory = new ConversationIndexMemory.Factory(); - memoryFactory.init(client, indicesHandler, memoryManager); - ActionListener listener = mock(ActionListener.class); - memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); - - verify(listener).onResponse(isA(ConversationIndexMemory.class)); - - Map memoryFactoryMap = new HashMap<>(); - memoryFactoryMap.put("memoryType", memoryFactory); - mlFlowAgentRunner.setMemoryFactoryMap(memoryFactoryMap); - - // Execute the method under test - mlFlowAgentRunner.updateMemory(new HashMap<>(), memorySpec, "memoryId", "interactionId"); - - // Asserting that the Memory Manager's updateInteraction method was called - verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class)); - } + // @Test + // public void testUpdateMemory() { + // // Mocking MLMemorySpec + // MLMemorySpec memorySpec = mock(MLMemorySpec.class); + // when(memorySpec.getType()).thenReturn("memoryType"); + // + // // Mocking Memory Factory and Memory + // + // ConversationIndexMemory.Factory memoryFactory = new ConversationIndexMemory.Factory(); + // memoryFactory.init(client, indicesHandler, memoryManager); + // ActionListener listener = mock(ActionListener.class); + // memoryFactory.create(Map.of(MEMORY_ID, "123", MEMORY_NAME, "name", APP_TYPE, "app"), listener); + // + // verify(listener).onResponse(isA(ConversationIndexMemory.class)); + // + // Map memoryFactoryMap = new HashMap<>(); + // memoryFactoryMap.put("memoryType", memoryFactory); + // mlFlowAgentRunner.setMemoryFactoryMap(memoryFactoryMap); + // + // // Execute the method under test + // mlFlowAgentRunner.updateMemory(new HashMap<>(), memorySpec, "memoryId", "interactionId"); + // + // // Asserting that the Memory Manager's updateInteraction method was called + // verify(memoryManager).updateInteraction(anyString(), anyMap(), any(ActionListener.class)); + // } @Test public void testRunWithUpdateFailure() { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 5245ccc320..a8be11c5f1 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -49,10 +49,10 @@ import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; -import 
org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -153,7 +153,7 @@ public void setup() { ActionListener> listener = invocation.getArgument(0); listener.onResponse(generateInteractions()); return null; - }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), anyInt()); + }).when(conversationIndexMemory).getMessages(anyInt(), memoryInteractionCapture.capture()); // Setup memory manager doAnswer(invocation -> { @@ -370,7 +370,7 @@ public void testMessageHistoryLimits() { params.put("executor_message_history_limit", "3"); mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); - verify(conversationIndexMemory).getMessages(any(), eq(5)); + verify(conversationIndexMemory).getMessages(eq(5), any()); ArgumentCaptor executeCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class); verify(client).execute(eq(MLExecuteTaskAction.INSTANCE), executeCaptor.capture(), any()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java new file mode 100644 index 0000000000..5d84d01f4b --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse; +import org.opensearch.ml.common.transport.session.MLCreateSessionAction; +import org.opensearch.ml.common.transport.session.MLCreateSessionResponse; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.transport.client.Client; + +public class AgenticConversationMemoryTest { + + @Mock + private Client client; + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private MLMemoryManager memoryManager; + + private AgenticConversationMemory agenticMemory; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + + agenticMemory = new AgenticConversationMemory(client, "test_conversation_id", "test_memory_container_id"); + } + + @Test + public void testGetType() { + assert agenticMemory.getType().equals("agentic_conversation"); + } + + @Test + public void testSaveMessage() { + ConversationIndexMessage message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId("test_session") + .question("What is AI?") + .response("AI is artificial intelligence") + .finalAnswer(true) + .build(); + + // Mock memory container save (primary path) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(MLAddMemoriesResponse.builder().workingMemoryId("working_mem_123").build()); + return null; + 
}).when(client).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + + ActionListener testListener = ActionListener.wrap(response -> { + // Response should contain the working memory ID + assert response.getId().equals("working_mem_123"); + }, e -> { throw new RuntimeException("Should not fail", e); }); + + agenticMemory.save(message, null, null, "test_action", testListener); + + // Verify only memory container save was called (not conversation index) + verify(client, times(1)).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + verify(memoryManager, never()).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + } + + @Test + public void testFactoryCreate() { + AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); + factory.init(client, mlIndicesHandler, memoryManager); + + Map params = new HashMap<>(); + params.put("memory_id", "test_memory_id"); + params.put("memory_name", "Test Memory"); + params.put("app_type", "conversational"); + params.put("memory_container_id", "test_container_id"); + + ActionListener listener = ActionListener.wrap(memory -> { + assert memory.getMemoryContainerId().equals("test_container_id"); + }, e -> { throw new RuntimeException("Should not fail", e); }); + + factory.create(params, listener); + } + + @Test + public void testFactoryCreateWithNewSession() { + AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); + factory.init(client, mlIndicesHandler, memoryManager); + + // Mock session creation + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(MLCreateSessionResponse.builder().sessionId("new_session_123").status("created").build()); + return null; + }).when(client).execute(eq(MLCreateSessionAction.INSTANCE), any(), any()); + + Map params = new HashMap<>(); + params.put("memory_name", "New Session"); + params.put("app_type", "conversational"); + params.put("memory_container_id", "test_container_id"); + + ActionListener listener = ActionListener.wrap(memory -> { + assert memory.getConversationId().equals("new_session_123"); + assert memory.getMemoryContainerId().equals("test_container_id"); + }, e -> { throw new RuntimeException("Should not fail", e); }); + + factory.create(params, listener); + + // Verify session creation was called + verify(client, times(1)).execute(eq(MLCreateSessionAction.INSTANCE), any(), any()); + } + + @Test + public void testSaveWithoutMemoryContainerId() { + AgenticConversationMemory memoryWithoutContainer = new AgenticConversationMemory( + client, + "test_conversation_id", + null // No memory container ID = should fail + ); + + ConversationIndexMessage message = ConversationIndexMessage + .conversationIndexMessageBuilder() + .sessionId("test_session") + .question("What is AI?") + .response("AI is artificial intelligence") + .build(); + + ActionListener testListener = ActionListener.wrap(response -> { + throw new RuntimeException("Should have failed without memory container ID"); + }, e -> { + // Expected to fail + assert e instanceof IllegalStateException; + assert e.getMessage().contains("Memory container ID is not configured"); + }); + + memoryWithoutContainer.save(message, null, null, "test_action", testListener); + + // Verify no API calls were made + verify(memoryManager, never()).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); + verify(client, never()).execute(eq(MLAddMemoriesAction.INSTANCE), any(), any()); + } +} diff --git 
a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java index d1ac123d7c..3c400b40cb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/ConversationIndexMemoryTest.java @@ -20,10 +20,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.index.shard.ShardId; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; @@ -71,29 +68,29 @@ public void getType() { Assert.assertEquals(indexMemory.getType(), ConversationIndexMemory.TYPE); } - @Test - public void save() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(true); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); - - verify(indicesHandler).initMemoryMessageIndex(any()); - } - - @Test - public void save4() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onFailure(new RuntimeException()); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); - - verify(indicesHandler).initMemoryMessageIndex(any()); - } + // @Test + // public void save() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(true); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + // + // verify(indicesHandler).initMemoryMessageIndex(any()); + // } + + // @Test + // public void save4() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onFailure(new RuntimeException()); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false)); + // + // verify(indicesHandler).initMemoryMessageIndex(any()); + // } @Test public void save1() { @@ -119,66 +116,54 @@ public void save6() { verify(memoryManager).createInteraction(any(), any(), any(), any(), any(), any(), any(), any(), any()); } - @Test - public void save2() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(Boolean.TRUE); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); - return null; - }).when(client).index(any(), any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - 
verify(actionListener).onResponse(isA(IndexResponse.class)); - } - - @Test - public void save3() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onFailure(new RuntimeException()); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - verify(actionListener).onFailure(isA(RuntimeException.class)); - } - - @Test - public void save5() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(Boolean.FALSE); - return null; - }).when(indicesHandler).initMemoryMessageIndex(any()); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); - return null; - }).when(client).index(any(), any()); - ActionListener actionListener = mock(ActionListener.class); - indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); - - verify(actionListener).onFailure(isA(RuntimeException.class)); - } - - @Test - public void getMessages() { - ActionListener listener = mock(ActionListener.class); - indexMemory.getMessages("test_id", listener); - } - - @Test - public void getMessages1() { - ActionListener listener = mock(ActionListener.class); - indexMemory.getMessages(listener); - } + // @Test + // public void save2() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(Boolean.TRUE); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(1); + // listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + // return null; + // }).when(client).index(any(), any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // verify(actionListener).onResponse(isA(IndexResponse.class)); + // } + + // @Test + // public void save3() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onFailure(new RuntimeException()); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // verify(actionListener).onFailure(isA(RuntimeException.class)); + // } + + // @Test + // public void save5() { + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(0); + // listener.onResponse(Boolean.FALSE); + // return null; + // }).when(indicesHandler).initMemoryMessageIndex(any()); + // doAnswer(invocation -> { + // ActionListener listener = invocation.getArgument(1); + // listener.onResponse(new IndexResponse(new ShardId("test", "test", 1), "test", 1l, 1l, 1l, true)); + // return null; + // }).when(client).index(any(), any()); + // ActionListener actionListener = mock(ActionListener.class); + // indexMemory.save("test_id", new ConversationIndexMessage("test", "123", "question", "response", false), actionListener); + // + // 
verify(actionListener).onFailure(isA(RuntimeException.class)); + // } @Test public void clear() { @@ -187,12 +172,12 @@ public void clear() { indexMemory.clear(); } - @Test - public void remove() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); - indexMemory.remove("test_id"); - } + // @Test + // public void remove() { + // exceptionRule.expect(RuntimeException.class); + // exceptionRule.expectMessage("remove method is not supported in ConversationIndexMemory"); + // indexMemory.remove("test_id"); + // } @Test public void factory_create_emptyMap() { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 62de34961e..d9fe945498 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -171,11 +171,11 @@ import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.input.parameter.sample.SampleAlgoParams; import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters; +import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.settings.MLCommonsSettings; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.spi.MLCommonsExtension; -import org.opensearch.ml.common.spi.memory.Memory; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction; @@ -268,6 +268,7 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.engine.indices.MLInputDatasetHandler; +import org.opensearch.ml.engine.memory.AgenticConversationMemory; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; import org.opensearch.ml.engine.tools.AgentTool; @@ -843,6 +844,10 @@ public Collection createComponents( conversationIndexMemoryFactory.init(client, mlIndicesHandler, memoryManager); memoryFactoryMap.put(ConversationIndexMemory.TYPE, conversationIndexMemoryFactory); + AgenticConversationMemory.Factory agenticConversationMemoryFactory = new AgenticConversationMemory.Factory(); + conversationIndexMemoryFactory.init(client, mlIndicesHandler, memoryManager); + memoryFactoryMap.put(AgenticConversationMemory.TYPE, agenticConversationMemoryFactory); + MLAgentExecutor agentExecutor = new MLAgentExecutor( client, sdkClient, diff --git a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java b/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java deleted file mode 100644 index 3615384fce..0000000000 --- a/spi/src/main/java/org/opensearch/ml/common/spi/memory/Memory.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.spi.memory; - -import java.util.Map; - -import org.opensearch.core.action.ActionListener; - -/** - * A general memory interface. - * @param - */ -public interface Memory { - - /** - * Get memory type. - * @return - */ - String getType(); - - /** - * Save message to id. 
- * @param id memory id - * @param message message to be saved - */ - default void save(String id, T message) {} - - default void save(String id, T message, ActionListener listener) {} - - /** - * Get messages of memory id. - * @param id memory id - * @return - */ - default T[] getMessages(String id) { - return null; - } - - default void getMessages(String id, ActionListener listener) {} - - /** - * Clear all memory. - */ - void clear(); - - /** - * Remove memory of specific id. - * @param id memory id - */ - void remove(String id); - - interface Factory { - /** - * Create an instance of this Memory. - * - * @param params Parameters for the memory - * @param listener Action listen for the memory creation action - */ - void create(Map params, ActionListener listener); - } -} From b2f53541785ae059dc60ad1beb08dc48d88fb40b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 24 Oct 2025 23:27:01 -0700 Subject: [PATCH 10/58] agentic conversation memory Signed-off-by: Yaliang Wu --- .../ml/common/agent/MLMemorySpec.java | 16 +- .../MemoryConfigurationTests.java | 10 +- .../engine/algorithms/agent/AgentUtils.java | 14 + .../algorithms/agent/MLAgentExecutor.java | 8 +- .../algorithms/agent/MLChatAgentRunner.java | 9 +- .../MLConversationalFlowAgentRunner.java | 12 +- .../remote/HttpJsonConnectorExecutor.java | 1 - .../memory/AgenticConversationMemory.java | 418 +++++++++++++++++- .../memory/MemorySearchService.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 2 +- 10 files changed, 453 insertions(+), 39 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java index bba24db6c4..7476ad351c 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java @@ -26,32 +26,37 @@ public class MLMemorySpec implements ToXContentObject { public static final String MEMORY_TYPE_FIELD = "type"; public static final String WINDOW_SIZE_FIELD = "window_size"; public static final String SESSION_ID_FIELD = "session_id"; + public static final String MEMORY_CONTAINER_ID_FIELD = "memory_container_id"; private String type; @Setter private String sessionId; private Integer windowSize; + private String memoryContainerId; @Builder(toBuilder = true) - public MLMemorySpec(String type, String sessionId, Integer windowSize) { + public MLMemorySpec(String type, String sessionId, Integer windowSize, String memoryContainerId) { if (type == null) { throw new IllegalArgumentException("agent name is null"); } this.type = type; this.sessionId = sessionId; this.windowSize = windowSize; + this.memoryContainerId = memoryContainerId; } public MLMemorySpec(StreamInput input) throws IOException { type = input.readString(); sessionId = input.readOptionalString(); windowSize = input.readOptionalInt(); + memoryContainerId = input.readOptionalString(); } public void writeTo(StreamOutput out) throws IOException { out.writeString(type); out.writeOptionalString(sessionId); out.writeOptionalInt(windowSize); + out.writeOptionalString(memoryContainerId); } @Override @@ -64,6 +69,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (sessionId != null) { builder.field(SESSION_ID_FIELD, sessionId); } + if (memoryContainerId != null) { + builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); + } builder.endObject(); return builder; } @@ -72,6 +80,7 @@ public static MLMemorySpec parse(XContentParser parser) 
throws IOException { String type = null; String sessionId = null; Integer windowSize = null; + String memoryContainerId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -88,12 +97,15 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException { case WINDOW_SIZE_FIELD: windowSize = parser.intValue(); break; + case MEMORY_CONTAINER_ID_FIELD: + memoryContainerId = parser.text(); + break; default: parser.skipChildren(); break; } } - return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).build(); + return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).memoryContainerId(memoryContainerId).build(); } public static MLMemorySpec fromStream(StreamInput in) throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java index c48a68f933..e2dd298d35 100644 --- a/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java +++ b/common/src/test/java/org/opensearch/ml/common/memorycontainer/MemoryConfigurationTests.java @@ -1088,11 +1088,11 @@ public void testMemoryConfiguration_WithRemoteStore() { RemoteStore remoteStore = RemoteStore.builder().type("aoss").connectorId("ySf08JkBym-3qj1O2uub").build(); MemoryConfiguration config = MemoryConfiguration - .builder() - .indexPrefix("test") - .useSystemIndex(false) - .remoteStore(remoteStore) - .build(); + .builder() + .indexPrefix("test") + .useSystemIndex(false) + .remoteStore(remoteStore) + .build(); assertNotNull(config.getRemoteStore()); assertEquals("aoss", config.getRemoteStore().getType()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..3df1cb0c66 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.isJson; @@ -29,6 +30,7 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; import java.io.IOException; @@ -83,6 +85,7 @@ import org.opensearch.ml.engine.algorithms.remote.McpStreamableHttpConnectorExecutor; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; +import 
org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.tools.McpSseTool; import org.opensearch.ml.engine.tools.McpStreamableHttpTool; import org.opensearch.remote.metadata.client.GetDataObjectRequest; @@ -1014,4 +1017,15 @@ public static Tool createTool(Map toolFactories, Map createMemoryParams(String question, String memoryId, String appType, MLAgent mlAgent) { + Map memoryParams = new HashMap<>(); + memoryParams.put(ConversationIndexMemory.MEMORY_NAME, question); + memoryParams.put(ConversationIndexMemory.MEMORY_ID, memoryId); + memoryParams.put(APP_TYPE, appType); + if (mlAgent.getMemory().getMemoryContainerId() != null) { + memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId()); + } + return memoryParams; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 9c0ad8a3b2..844c4b6136 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import java.security.AccessController; @@ -248,11 +249,8 @@ public void execute(Input input, ActionListener listener, TransportChann && (memoryId == null || parentInteractionId == null)) { Memory.Factory> memoryFactory = memoryFactoryMap.get(memorySpec.getType()); - Map params = new HashMap<>(); - params.put(ConversationIndexMemory.MEMORY_NAME, question); - params.put(ConversationIndexMemory.MEMORY_ID, memoryId); - params.put(APP_TYPE, appType); - memoryFactory.create(params, ActionListener.wrap(memory -> { + Map memoryParams = createMemoryParams(question, memoryId, appType, mlAgent); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { inputDataSet.getParameters().put(MEMORY_ID, memory.getId()); // get question for regenerate if (regenerateInteractionId != null) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index dad28a44f0..9ee9bd7658 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,7 +7,6 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; -import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -25,6 +24,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE; import static 
org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; @@ -35,8 +35,6 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import java.security.PrivilegedActionException; @@ -188,9 +186,10 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - Memory.Factory> factory = memoryFactoryMap.get(memoryType); + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); - factory.create(Map.of(MEMORY_ID, memoryId, MEMORY_NAME, title, APP_TYPE, appType), ActionListener.wrap(memory -> { + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 86b46bbf56..87e7b761f0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -15,16 +15,14 @@ import static org.opensearch.ml.common.utils.ToolUtils.getToolName; import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTool; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMessageHistoryLimit; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.security.PrivilegedActionException; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -113,12 +111,8 @@ public void run(MLAgent mlAgent, Map params, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); - Map createMemoryParams = new HashMap<>(); - 
params.put(MEMORY_NAME, title); - params.put(MEMORY_ID, memoryId); - params.put(APP_TYPE, appType); - - memoryFactory.create(createMemoryParams, ActionListener.wrap(memory -> { + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); for (Interaction next : r) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 7c6e89e076..45b318bc6c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -13,7 +13,6 @@ import static software.amazon.awssdk.http.SdkHttpMethod.POST; import static software.amazon.awssdk.http.SdkHttpMethod.PUT; -import java.net.URL; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java index 7cd2484078..81c58d3aaa 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -9,18 +9,40 @@ import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.memory.Message; +import org.opensearch.ml.common.memorycontainer.MLWorkingMemory; +import org.opensearch.ml.common.memorycontainer.MemoryType; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesRequest; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest; import org.opensearch.ml.common.transport.session.MLCreateSessionAction; import 
org.opensearch.ml.common.transport.session.MLCreateSessionInput; import org.opensearch.ml.common.transport.session.MLCreateSessionRequest; import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; import org.opensearch.transport.client.Client; import lombok.Getter; @@ -34,23 +56,35 @@ @Getter public class AgenticConversationMemory implements Memory { - public static final String TYPE = "agentic_memory"; + public static final String TYPE = "agentic_conversation"; + private static final String SESSION_ID_FIELD = "session_id"; + private static final String CREATED_TIME_FIELD = "created_time"; - public AgenticConversationMemory(Client client, String memoryId, String memoryContainerId) {} + private final Client client; + private final String conversationId; + private final String memoryContainerId; + + public AgenticConversationMemory(Client client, String memoryId, String memoryContainerId) { + this.client = client; + this.conversationId = memoryId; + this.memoryContainerId = memoryContainerId; + } @Override public String getType() { - return ""; + return TYPE; } @Override public String getId() { - return ""; + return conversationId; } @Override public void save(Message message, String parentId, Integer traceNum, String action) { - Memory.super.save(message, parentId, traceNum, action); + this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { + log.info("Saved message to agentic memory, session id: {}, working memory id: {}", conversationId, r.getId()); + }, e -> { log.error("Failed to save message to agentic memory", e); })); } @Override @@ -61,27 +95,391 @@ public void save( String action, ActionListener listener ) { - Memory.super.save(message, parentId, traceNum, action, listener); + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalStateException( + "Memory container ID is not configured for this AgenticConversationMemory. " + + "Cannot save messages without a valid memory container." + ) + ); + return; + } + + ConversationIndexMessage msg = (ConversationIndexMessage) message; + + // Build namespace with session_id + Map namespace = new HashMap<>(); + namespace.put(SESSION_ID_FIELD, conversationId); + + // Simple rule matching ConversationIndexMemory: + // - If traceNum != null → it's a trace + // - If traceNum == null → it's a message + boolean isTrace = (traceNum != null); + + Map metadata = new HashMap<>(); + Map structuredData = new HashMap<>(); + + // Store data in structured_data format matching conversation index + structuredData.put("input", msg.getQuestion() != null ? msg.getQuestion() : ""); + structuredData.put("response", msg.getResponse() != null ? 
msg.getResponse() : ""); + + if (isTrace) { + // This is a trace (tool usage or intermediate step) + metadata.put("type", "trace"); + if (parentId != null) { + metadata.put("parent_message_id", parentId); + structuredData.put("parent_message_id", parentId); + } + metadata.put("trace_number", String.valueOf(traceNum)); + structuredData.put("trace_number", traceNum); + if (action != null) { + metadata.put("origin", action); + structuredData.put("origin", action); + } + } else { + // This is a final message (Q&A pair) + metadata.put("type", "message"); + if (msg.getFinalAnswer() != null) { + structuredData.put("final_answer", msg.getFinalAnswer()); + } + } + + // Add timestamps + java.time.Instant now = java.time.Instant.now(); + structuredData.put("create_time", now.toString()); + structuredData.put("updated_time", now.toString()); + + // Create MLAddMemoriesInput + MLAddMemoriesInput input = MLAddMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .structuredData(structuredData) + .messageId(traceNum) // Store trace number in messageId field (null for messages) + .namespace(namespace) + .metadata(metadata) + .infer(false) // Don't infer long-term memory by default + .build(); + + MLAddMemoriesRequest request = MLAddMemoriesRequest.builder().mlAddMemoryInput(input).build(); + + // Execute the add memories action + client.execute(MLAddMemoriesAction.INSTANCE, request, ActionListener.wrap(response -> { + // Convert MLAddMemoriesResponse to CreateInteractionResponse + CreateInteractionResponse interactionResponse = new CreateInteractionResponse(response.getWorkingMemoryId()); + listener.onResponse(interactionResponse); + }, e -> { + log.error("Failed to add memories to memory container", e); + listener.onFailure(e); + })); } @Override public void update(String messageId, Map updateContent, ActionListener updateListener) { - Memory.super.update(messageId, updateContent, updateListener); + if (Strings.isNullOrEmpty(memoryContainerId)) { + updateListener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Step 1: Get the existing working memory to retrieve current structured_data + MLGetMemoryRequest getRequest = MLGetMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .memoryId(messageId) + .build(); + + client.execute(MLGetMemoryAction.INSTANCE, getRequest, ActionListener.wrap(getResponse -> { + // Step 2: Extract existing structured_data and merge with updates + MLWorkingMemory workingMemory = getResponse.getWorkingMemory(); + if (workingMemory == null) { + updateListener.onFailure(new IllegalStateException("Working memory not found for id: " + messageId)); + return; + } + + Map structuredData = workingMemory.getStructuredData(); + if (structuredData == null) { + structuredData = new HashMap<>(); + } else { + // Create a mutable copy + structuredData = new HashMap<>(structuredData); + } + + // Step 3: Merge update content into structured_data + // The updateContent contains fields like "response" and "additional_info" + // These should be stored in structured_data + for (Map.Entry entry : updateContent.entrySet()) { + structuredData.put(entry.getKey(), entry.getValue()); + } + + // Update the timestamp + // structuredData.put("updated_time", java.time.Instant.now().toString()); + + // Step 4: Create update request with merged structured_data + Map finalUpdateContent = new HashMap<>(); + finalUpdateContent.put("structured_data", structuredData); + 
+ MLUpdateMemoryInput input = MLUpdateMemoryInput.builder().updateContent(finalUpdateContent).build(); + + MLUpdateMemoryRequest updateRequest = MLUpdateMemoryRequest + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .memoryId(messageId) + .mlUpdateMemoryInput(input) + .build(); + + // Step 5: Execute the update + client.execute(MLUpdateMemoryAction.INSTANCE, updateRequest, ActionListener.wrap(indexResponse -> { + // Convert IndexResponse to UpdateResponse + UpdateResponse updateResponse = new UpdateResponse( + indexResponse.getShardInfo(), + indexResponse.getShardId(), + indexResponse.getId(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + indexResponse.getVersion(), + indexResponse.getResult() + ); + updateListener.onResponse(updateResponse); + }, e -> { + log.error("Failed to update memory in memory container", e); + updateListener.onFailure(e); + })); + }, e -> { + log.error("Failed to get existing memory for update", e); + updateListener.onFailure(e); + })); } @Override public void getMessages(int size, ActionListener> listener) { - Memory.super.getMessages(size, listener); + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Build search query for working memory by session_id, filtering only final messages (not traces) + // Match ConversationIndexMemory pattern: exclude entries with trace_number + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.must(QueryBuilders.termQuery("namespace." + SESSION_ID_FIELD, conversationId)); + boolQuery.mustNot(QueryBuilders.existsQuery("structured_data.trace_number")); // Exclude traces + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQuery); + searchSourceBuilder.size(size); + searchSourceBuilder.sort(CREATED_TIME_FIELD, SortOrder.ASC); + + MLSearchMemoriesInput searchInput = MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(searchSourceBuilder) + .build(); + + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { + List interactions = parseSearchResponseToInteractions(searchResponse); + listener.onResponse(interactions); + }, e -> { + log.error("Failed to search memories in memory container", e); + listener.onFailure(e); + })); + } + + private List parseSearchResponseToInteractions(SearchResponse searchResponse) { + List interactions = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + Map sourceMap = hit.getSourceAsMap(); + + // Extract structured_data which contains the interaction data + @SuppressWarnings("unchecked") + Map structuredData = (Map) sourceMap.get("structured_data"); + + if (structuredData != null) { + String input = (String) structuredData.get("input"); + String response = (String) structuredData.get("response"); + + // Extract timestamps + Long createdTimeMs = (Long) sourceMap.get("created_time"); + Long updatedTimeMs = (Long) sourceMap.get("last_updated_time"); + + // Parse create_time from structured_data if available + String createTimeStr = (String) structuredData.get("create_time"); + String updatedTimeStr = (String) structuredData.get("updated_time"); + + java.time.Instant 
createTime = null; + java.time.Instant updatedTime = null; + + if (createTimeStr != null) { + try { + createTime = java.time.Instant.parse(createTimeStr); + } catch (Exception e) { + log.warn("Failed to parse create_time from structured_data", e); + } + } + if (updatedTimeStr != null) { + try { + updatedTime = java.time.Instant.parse(updatedTimeStr); + } catch (Exception e) { + log.warn("Failed to parse updated_time from structured_data", e); + } + } + + // Fallback to document timestamps if structured_data timestamps not available + if (createTime == null && createdTimeMs != null) { + createTime = java.time.Instant.ofEpochMilli(createdTimeMs); + } + if (updatedTime == null && updatedTimeMs != null) { + updatedTime = java.time.Instant.ofEpochMilli(updatedTimeMs); + } + + // Extract metadata + @SuppressWarnings("unchecked") + Map metadata = (Map) sourceMap.get("metadata"); + String parentInteractionId = metadata != null ? metadata.get("parent_message_id") : null; + + // Create Interaction object + if (input != null || response != null) { + Interaction interaction = Interaction + .builder() + .id(hit.getId()) + .conversationId(conversationId) + .createTime(createTime != null ? createTime : java.time.Instant.now()) + .updatedTime(updatedTime) + .input(input != null ? input : "") + .response(response != null ? response : "") + .origin("agentic_memory") + .promptTemplate(null) + .additionalInfo(null) + .parentInteractionId(parentInteractionId) + .traceNum(null) // Messages don't have trace numbers + .build(); + interactions.add(interaction); + } + } + } + return interactions; } @Override public void clear() { - + throw new UnsupportedOperationException("clear method is not supported in AgenticConversationMemory"); } @Override - public void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener wrap) { + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + // For now, delegate to a simple implementation + // In the future, this could use MLDeleteMemoryAction + log.warn("deleteInteractionAndTrace is not fully implemented for AgenticConversationMemory"); + listener.onResponse(false); + } + /** + * Get traces (intermediate steps/tool usage) for a specific parent message + * @param parentMessageId The parent message ID + * @param listener Action listener for the traces + */ + public void getTraces(String parentMessageId, ActionListener> listener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this AgenticConversationMemory")); + return; + } + + // Build search query for traces by parent_message_id + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.must(QueryBuilders.termQuery("namespace." 
+ SESSION_ID_FIELD, conversationId)); + boolQuery.must(QueryBuilders.termQuery("metadata.type", "trace")); // Only get traces + boolQuery.must(QueryBuilders.termQuery("metadata.parent_message_id", parentMessageId)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(boolQuery); + searchSourceBuilder.size(1000); // Get all traces for this message + searchSourceBuilder.sort("message_id", SortOrder.ASC); // Sort by trace number + + MLSearchMemoriesInput searchInput = MLSearchMemoriesInput + .builder() + .memoryContainerId(memoryContainerId) + .memoryType(MemoryType.WORKING) + .searchSourceBuilder(searchSourceBuilder) + .build(); + + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + + client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { + List traces = parseSearchResponseToTraces(searchResponse); + listener.onResponse(traces); + }, e -> { + log.error("Failed to search traces in memory container", e); + listener.onFailure(e); + })); + } + + private List parseSearchResponseToTraces(SearchResponse searchResponse) { + List traces = new ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + Map sourceMap = hit.getSourceAsMap(); + + // Extract structured_data which contains the trace data + @SuppressWarnings("unchecked") + Map structuredData = (Map) sourceMap.get("structured_data"); + + if (structuredData != null) { + String input = (String) structuredData.get("input"); + String response = (String) structuredData.get("response"); + String origin = (String) structuredData.get("origin"); + String parentMessageId = (String) structuredData.get("parent_message_id"); + + // Extract trace number + Integer traceNum = null; + Object traceNumObj = structuredData.get("trace_number"); + if (traceNumObj instanceof Integer) { + traceNum = (Integer) traceNumObj; + } else if (traceNumObj instanceof String) { + try { + traceNum = Integer.parseInt((String) traceNumObj); + } catch (NumberFormatException e) { + log.warn("Failed to parse trace_number", e); + } + } + + // Also check message_id field which stores trace number + Integer messageId = (Integer) sourceMap.get("message_id"); + if (traceNum == null && messageId != null) { + traceNum = messageId; + } + + // Extract timestamps + Long createdTimeMs = (Long) sourceMap.get("created_time"); + Long updatedTimeMs = (Long) sourceMap.get("last_updated_time"); + + java.time.Instant createTime = createdTimeMs != null + ? java.time.Instant.ofEpochMilli(createdTimeMs) + : java.time.Instant.now(); + java.time.Instant updatedTime = updatedTimeMs != null ? java.time.Instant.ofEpochMilli(updatedTimeMs) : null; + + // Create Interaction object for trace + if (input != null || response != null) { + Interaction trace = Interaction + .builder() + .id(hit.getId()) + .conversationId(conversationId) + .createTime(createTime) + .updatedTime(updatedTime) + .input(input != null ? input : "") + .response(response != null ? response : "") + .origin(origin != null ? 
origin : "") + .promptTemplate(null) + .additionalInfo(null) + .parentInteractionId(parentMessageId) + .traceNum(traceNum) + .build(); + traces.add(trace); + } + } + } + return traces; } /** diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java index 2300ffbec9..f00ad80cd1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java @@ -104,7 +104,7 @@ private void searchFactsSequentially( input.getMemoryContainerId() ); - log.debug("Searching for similar facts"); + log.debug("Searching for similar facts"); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(queryBuilder); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index d9fe945498..dc47e87059 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -845,7 +845,7 @@ public Collection createComponents( memoryFactoryMap.put(ConversationIndexMemory.TYPE, conversationIndexMemoryFactory); AgenticConversationMemory.Factory agenticConversationMemoryFactory = new AgenticConversationMemory.Factory(); - conversationIndexMemoryFactory.init(client, mlIndicesHandler, memoryManager); + agenticConversationMemoryFactory.init(client); memoryFactoryMap.put(AgenticConversationMemory.TYPE, agenticConversationMemoryFactory); MLAgentExecutor agentExecutor = new MLAgentExecutor( From 94d97e747392fa52da2008acdf7f270e218293d6 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sat, 25 Oct 2025 02:52:29 -0700 Subject: [PATCH 11/58] fix remote store issues Signed-off-by: Yaliang Wu --- .../remote/AwsConnectorExecutor.java | 4 +- .../remote/HttpJsonConnectorExecutor.java | 4 +- .../memory/AgenticConversationMemory.java | 78 +++++++++++++++++-- .../ml/helper/MemoryContainerHelper.java | 5 +- 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index f500ae32d1..8fad172811 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -26,6 +26,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.TokenBucket; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; @@ -123,6 +124,7 @@ public void invokeRemoteService( default: throw new IllegalArgumentException("unsupported http method"); } + ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().newStoredContext(true); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(signRequest(request)) @@ -130,7 +132,7 @@ public void invokeRemoteService( .responseHandler( new MLSdkAsyncHttpResponseHandler( executionContext, - actionListener, + 
ActionListener.runBefore(actionListener, storedContext::restore), // Restore context before calling listener, parameters, connector, scriptService, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 45b318bc6c..bc2446034c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -25,6 +25,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.common.collect.Tuple; import org.opensearch.common.util.TokenBucket; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; @@ -122,6 +123,7 @@ public void invokeRemoteService( default: throw new IllegalArgumentException("unsupported http method"); } + ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().newStoredContext(true); AsyncExecuteRequest executeRequest = AsyncExecuteRequest .builder() .request(request) @@ -129,7 +131,7 @@ public void invokeRemoteService( .responseHandler( new MLSdkAsyncHttpResponseHandler( executionContext, - actionListener, + ActionListener.runBefore(actionListener, storedContext::restore), parameters, connector, scriptService, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java index 81c58d3aaa..a1abee2089 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -56,7 +56,7 @@ @Getter public class AgenticConversationMemory implements Memory { - public static final String TYPE = "agentic_conversation"; + public static final String TYPE = "agentic_memory"; private static final String SESSION_ID_FIELD = "session_id"; private static final String CREATED_TIME_FIELD = "created_time"; @@ -181,6 +181,23 @@ public void update(String messageId, Map updateContent, ActionLi return; } + // Use retry mechanism for AOSS compatibility (high refresh latency) + updateWithRetry(messageId, updateContent, updateListener, 0); + } + + /** + * Update with retry mechanism to handle AOSS refresh latency (up to 10s) + * Uses exponential backoff: 500ms, 1s, 2s, 4s, 8s + */ + private void updateWithRetry( + String messageId, + Map updateContent, + ActionListener updateListener, + int attemptNumber + ) { + final int maxRetries = 5; + final long baseDelayMs = 500; + // Step 1: Get the existing working memory to retrieve current structured_data MLGetMemoryRequest getRequest = MLGetMemoryRequest .builder() @@ -247,8 +264,41 @@ public void update(String messageId, Map updateContent, ActionLi updateListener.onFailure(e); })); }, e -> { - log.error("Failed to get existing memory for update", e); - updateListener.onFailure(e); + // Check if it's a 404 (document not found) and we haven't exceeded max retries + boolean isNotFound = e.getMessage() != null && (e.getMessage().contains("404") || e.getMessage().contains("\"found\":false")); + + if (isNotFound && attemptNumber < maxRetries) { + // Calculate delay with exponential backoff + 
long delayMs = baseDelayMs * (1L << attemptNumber); // 500ms, 1s, 2s, 4s, 8s + + log + .warn( + "Document not found (attempt {}/{}), retrying after {}ms due to AOSS refresh latency. MessageId: {}", + attemptNumber + 1, + maxRetries, + delayMs, + messageId + ); + + // Schedule retry after delay + try { + Thread.sleep(delayMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + updateListener.onFailure(new RuntimeException("Retry interrupted", ie)); + return; + } + + // Retry + updateWithRetry(messageId, updateContent, updateListener, attemptNumber + 1); + } else { + if (attemptNumber >= maxRetries) { + log.error("Failed to get existing memory after {} retries. MessageId: {}", maxRetries, messageId, e); + } else { + log.error("Failed to get existing memory for update. MessageId: {}", messageId, e); + } + updateListener.onFailure(e); + } })); } @@ -301,9 +351,9 @@ private List parseSearchResponseToInteractions(SearchResponse searchRes String input = (String) structuredData.get("input"); String response = (String) structuredData.get("response"); - // Extract timestamps - Long createdTimeMs = (Long) sourceMap.get("created_time"); - Long updatedTimeMs = (Long) sourceMap.get("last_updated_time"); + // Extract timestamps - handle both Long and Double from OpenSearch + Long createdTimeMs = convertToLong(sourceMap.get("created_time")); + Long updatedTimeMs = convertToLong(sourceMap.get("last_updated_time")); // Parse create_time from structured_data if available String createTimeStr = (String) structuredData.get("create_time"); @@ -564,4 +614,20 @@ public void create(String memoryId, String memoryContainerId, ActionListener documentSource = indexRequest.sourceAsMap(); - - RemoteStorageHelper.writeDocument(connectorId, indexName, documentSource, client, ActionListener.wrap(response -> { - listener.onResponse(response); - }, listener::onFailure)); + RemoteStorageHelper.writeDocument(connectorId, indexName, documentSource, client, listener); } catch (Exception e) { log.error("Failed to index data to remote storage", e); listener.onFailure(e); From 14b2206d4aea7f82ee0c9432d0f5dff83f4880a9 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Sat, 25 Oct 2025 12:01:25 -0700 Subject: [PATCH 12/58] extending memory refactoring to PER agent (#4348) Signed-off-by: Dhrubo Saha --- .../org/opensearch/ml/common/MLAgentType.java | 2 +- .../opensearch/ml/common/MLMemoryType.java | 24 ++++++ .../opensearch/ml/common/agent/MLAgent.java | 2 +- .../transport/agent/MLAgentUpdateInput.java | 5 +- .../ml/common/MLAgentTypeTests.java | 4 +- .../agent/MLAgentUpdateInputTest.java | 2 +- .../algorithms/agent/MLAgentExecutor.java | 23 ++++-- .../algorithms/agent/MLChatAgentRunner.java | 3 +- .../MLPlanExecuteAndReflectAgentRunner.java | 75 +++++++++---------- .../memory/AgenticConversationMemory.java | 3 +- .../memory/ConversationIndexMemory.java | 3 +- 11 files changed, 91 insertions(+), 55 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/MLMemoryType.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java index 2dd2614634..04a4b72014 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLAgentType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLAgentType.java @@ -20,7 +20,7 @@ public static MLAgentType from(String value) { try { return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT)); } catch (Exception e) { - throw new IllegalArgumentException("Wrong Agent type"); 
+ throw new IllegalArgumentException(value + " is not a valid Agent Type"); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java new file mode 100644 index 0000000000..31939ce1ca --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import java.util.Locale; + +public enum MLMemoryType { + CONVERSATION_INDEX, + AGENTIC_MEMORY; + + public static MLMemoryType from(String value) { + if (value != null) { + try { + return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong Memory type"); + } + } + return null; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..73dcd3a7a8 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -113,7 +113,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - validateMLAgentType(type); + MLAgentType.from(type); if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { throw new IllegalArgumentException("We need model information for the conversational agent type"); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java index 9a0d6002fd..e85b3f4bdc 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInput.java @@ -26,6 +26,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; @@ -383,9 +384,7 @@ private void validate() { String.format("Agent name cannot be empty or exceed max length of %d characters", MLAgent.AGENT_NAME_MAX_LENGTH) ); } - if (memoryType != null && !memoryType.equals("conversation_index")) { - throw new IllegalArgumentException(String.format("Invalid memory type: %s", memoryType)); - } + MLMemoryType.from(memoryType); if (tools != null) { Set toolNames = new HashSet<>(); for (MLToolSpec toolSpec : tools) { diff --git a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java index ee15ca95fd..05f37c4992 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLAgentTypeTests.java @@ -44,14 +44,14 @@ public void testFromWithMixedCase() { public void testFromWithInvalidType() { // This should throw an IllegalArgumentException exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); MLAgentType.from("INVALID_TYPE"); } @Test public void testFromWithEmptyString() { 
exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Agent type"); + exceptionRule.expectMessage(" is not a valid Agent Type"); // This should also throw an IllegalArgumentException MLAgentType.from(""); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java index 72eb035279..084f95d137 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateInputTest.java @@ -94,7 +94,7 @@ public void testValidationWithInvalidMemoryType() { IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> { MLAgentUpdateInput.builder().agentId("test-agent-id").name("test-agent").memoryType("invalid_type").build(); }); - assertEquals("Invalid memory type: invalid_type", e.getMessage()); + assertEquals("Wrong Memory type", e.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 844c4b6136..11b3a5b976 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -17,7 +17,6 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import java.security.AccessController; import java.security.PrivilegedActionException; @@ -48,6 +47,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -245,9 +245,10 @@ public void execute(Input input, ActionListener listener, TransportChann } if (memorySpec != null && memorySpec.getType() != null - && memoryFactoryMap.containsKey(memorySpec.getType()) + && memoryFactoryMap.containsKey(MLMemoryType.from(memorySpec.getType()).name()) && (memoryId == null || parentInteractionId == null)) { - Memory.Factory> memoryFactory = memoryFactoryMap.get(memorySpec.getType()); + Memory.Factory> memoryFactory = memoryFactoryMap + .get(MLMemoryType.from(memorySpec.getType()).name()); Map memoryParams = createMemoryParams(question, memoryId, appType, mlAgent); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { @@ -299,14 +300,24 @@ public void execute(Input input, ActionListener listener, TransportChann } else { // For existing conversations, create memory instance using factory if (memorySpec != null && memorySpec.getType() != null) { + String memoryType = MLMemoryType.from(memorySpec.getType()).name(); + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap .get(memorySpec.getType()); - if (factory != null) { + if (memoryFactory != null) { // memoryId exists, so create returns an 
object with existing memory, therefore name can // be null - factory + Map memoryParams = createMemoryParams( + question, + memoryId, + appType, + mlAgent + ); + + memoryFactory .create( - Map.of(MEMORY_ID, memoryId, APP_TYPE, appType), + memoryParams, ActionListener .wrap( createdMemory -> executeAgent( diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 9ee9bd7658..a8b7b54d71 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -58,6 +58,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; @@ -177,7 +178,7 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener functionCalling.configure(params); } - String memoryType = mlAgent.getMemory().getType(); + String memoryType = MLMemoryType.from(mlAgent.getMemory().getType()).name(); String memoryId = params.get(MLAgentExecutor.MEMORY_ID); String appType = mlAgent.getAppType(); String title = params.get(MLAgentExecutor.QUESTION); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 96d314503c..2a3673824c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -16,6 +16,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createMemoryParams; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; @@ -32,9 +33,6 @@ import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; -import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; import java.util.ArrayList; import java.util.HashMap; @@ -51,6 +49,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.MLTaskState; import 
org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -285,42 +284,42 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListenerwrap(memory -> { - memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { - List completedSteps = new ArrayList<>(); - for (Interaction interaction : interactions) { - String question = interaction.getInput(); - String response = interaction.getResponse(); - - if (Strings.isNullOrEmpty(response)) { - continue; - } - - completedSteps.add(question); - completedSteps.add(response); - } + // ConversationIndexMemory.Factory conversationIndexMemoryFactory = (ConversationIndexMemory.Factory) + // memoryFactoryMap.get(memoryType); + + Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); + Map memoryParams = createMemoryParams(apiParams.get(USER_PROMPT_FIELD), memoryId, appType, mlAgent); + memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { + memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { + List completedSteps = new ArrayList<>(); + for (Interaction interaction : interactions) { + String question = interaction.getInput(); + String response = interaction.getResponse(); + + if (Strings.isNullOrEmpty(response)) { + continue; + } - if (!completedSteps.isEmpty()) { - addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); - usePlannerWithHistoryPromptTemplate(allParams); - } + completedSteps.add(question); + completedSteps.add(response); + } - setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getConversationId(), listener); - }, e -> { - log.error("Failed to get chat history", e); - listener.onFailure(e); - })); - }, listener::onFailure) - ); + if (!completedSteps.isEmpty()) { + addSteps(completedSteps, allParams, COMPLETED_STEPS_FIELD); + usePlannerWithHistoryPromptTemplate(allParams); + } + + setToolsAndRunAgent(mlAgent, allParams, completedSteps, memory, memory.getId(), listener); + }, e -> { + log.error("Failed to get chat history", e); + listener.onFailure(e); + })); + }, listener::onFailure)); } private void setToolsAndRunAgent( @@ -412,7 +411,7 @@ private void executePlanningLoop( if (parseLLMOutput.get(RESULT_FIELD) != null) { String finalResult = (String) parseLLMOutput.get(RESULT_FIELD); saveAndReturnFinalResult( - (ConversationIndexMemory) memory, + memory, parentInteractionId, allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), @@ -512,7 +511,7 @@ private void executePlanningLoop( completedSteps.add(String.format("\nStep %d Result: %s\n", stepsExecuted + 1, results.get(STEP_RESULT_FIELD))); saveTraceData( - (ConversationIndexMemory) memory, + memory, memory.getType(), stepToExecute, results.get(STEP_RESULT_FIELD), @@ -636,7 +635,7 @@ void addSteps(List steps, Map allParams, String field) { @VisibleForTesting void saveAndReturnFinalResult( - ConversationIndexMemory memory, + Memory memory, String parentInteractionId, String reactAgentMemoryId, String reactParentInteractionId, @@ -651,9 +650,9 @@ void saveAndReturnFinalResult( updateContent.put(INTERACTIONS_INPUT_FIELD, input); } - memory.getMemoryManager().updateInteraction(parentInteractionId, updateContent, ActionListener.wrap(res -> { + memory.update(parentInteractionId, updateContent, ActionListener.wrap(res -> { List finalModelTensors = createModelTensors( - memory.getConversationId(), + memory.getId(), parentInteractionId, reactAgentMemoryId, reactParentInteractionId diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java index a1abee2089..3b27275afe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -20,6 +20,7 @@ import org.opensearch.core.common.Strings; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.memory.Message; @@ -56,7 +57,7 @@ @Getter public class AgenticConversationMemory implements Memory { - public static final String TYPE = "agentic_memory"; + public static final String TYPE = MLMemoryType.AGENTIC_MEMORY.name(); private static final String SESSION_ID_FIELD = "session_id"; private static final String CREATED_TIME_FIELD = "created_time"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java index e8e5f87a5a..fa861d7ac0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/ConversationIndexMemory.java @@ -13,6 +13,7 @@ import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -26,7 +27,7 @@ @Log4j2 @Getter public class ConversationIndexMemory implements Memory { - public static final String TYPE = "conversation_index"; + public static final String TYPE = MLMemoryType.CONVERSATION_INDEX.name(); public static final String CONVERSATION_ID = "conversation_id"; public static final String FINAL_ANSWER = "final_answer"; public static final String CREATED_TIME = "created_time"; From c5ad5c3a33a6b0c3f5ee19379d08915df5850e21 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 29 Oct 2025 00:02:07 -0700 Subject: [PATCH 13/58] Merge agent-revamp into feature/memory as a single commit Signed-off-by: Pavan Yekbote --- .../ml/common/agent/AgentInput.java | 555 ++++++++++++++++++ .../ml/common/agent/AgentInputProcessor.java | 199 +++++++ .../ml/common/agent/AgentModelService.java | 78 +++ .../agent/BedrockConverseModelProvider.java | 279 +++++++++ .../ml/common/agent/ContentBlock.java | 72 +++ .../ml/common/agent/ContentType.java | 16 + .../ml/common/agent/DocumentContent.java | 22 + .../ml/common/agent/ImageContent.java | 22 + .../opensearch/ml/common/agent/InputType.java | 16 + .../ml/common/agent/InputValidator.java | 288 +++++++++ .../opensearch/ml/common/agent/MLAgent.java | 39 ++ .../ml/common/agent/MLAgentModelSpec.java | 140 +++++ .../opensearch/ml/common/agent/Message.java | 23 + .../ml/common/agent/ModelProvider.java | 101 ++++ .../ml/common/agent/ModelProviderFactory.java | 28 + .../ml/common/agent/ModelProviderType.java | 34 ++ .../ml/common/agent/SourceType.java | 14 + .../ml/common/agent/VideoContent.java | 22 + .../connector/ConnectorClientConfig.java | 2 + 
.../input/execute/agent/AgentMLInput.java | 49 ++ .../algorithms/agent/MLAgentExecutor.java | 54 +- .../algorithms/agent/MLChatAgentRunner.java | 14 + .../agents/TransportRegisterAgentAction.java | 46 ++ 23 files changed, 2109 insertions(+), 4 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/AgentInput.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/AgentInputProcessor.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/AgentModelService.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/BedrockConverseModelProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ContentBlock.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ContentType.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/DocumentContent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ImageContent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/InputType.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/InputValidator.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/MLAgentModelSpec.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/Message.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ModelProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ModelProviderFactory.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/ModelProviderType.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/SourceType.java create mode 100644 common/src/main/java/org/opensearch/ml/common/agent/VideoContent.java diff --git a/common/src/main/java/org/opensearch/ml/common/agent/AgentInput.java b/common/src/main/java/org/opensearch/ml/common/agent/AgentInput.java new file mode 100644 index 0000000000..dc2aeccfbd --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/AgentInput.java @@ -0,0 +1,555 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents standardized agent input that can handle different input formats: + * - Plain text (String) + * - Multi-modal content blocks (List) + * - Message-based conversations (List) + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AgentInput implements Writeable { + // String, List, or List + private Object input; + + /** + * Constructor for stream input deserialization. + * Supports all input types including images, videos, and documents. 
+ */ + public AgentInput(StreamInput in) throws IOException { + InputType inputType = InputType.valueOf(in.readString()); + switch (inputType) { + case TEXT: + this.input = in.readString(); + break; + case CONTENT_BLOCKS: + this.input = readContentBlocksList(in); + break; + case MESSAGES: + this.input = readMessagesList(in); + break; + default: + throw new IOException("Unsupported input type: " + inputType); + } + } + + /** + * Reads a list of ContentBlocks from stream input. + */ + private List readContentBlocksList(StreamInput in) throws IOException { + int size = in.readInt(); + List contentBlocks = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + ContentBlock block = readContentBlock(in); + contentBlocks.add(block); + } + + return contentBlocks; + } + + /** + * Reads a single ContentBlock from stream input. + */ + private ContentBlock readContentBlock(StreamInput in) throws IOException { + ContentType type = ContentType.valueOf(in.readString()); + ContentBlock block = new ContentBlock(); + block.setType(type); + + switch (type) { + case TEXT: + block.setText(in.readString()); + break; + case IMAGE: + block.setImage(readImageContent(in)); + break; + case VIDEO: + block.setVideo(readVideoContent(in)); + break; + case DOCUMENT: + block.setDocument(readDocumentContent(in)); + break; + } + + return block; + } + + /** + * Reads ImageContent from stream input. + */ + private ImageContent readImageContent(StreamInput in) throws IOException { + SourceType sourceType = SourceType.valueOf(in.readString()); + String format = in.readString(); + String data = in.readString(); + + return new ImageContent(sourceType, format, data); + } + + /** + * Reads VideoContent from stream input. + */ + private VideoContent readVideoContent(StreamInput in) throws IOException { + SourceType sourceType = SourceType.valueOf(in.readString()); + String format = in.readString(); + String data = in.readString(); + + return new VideoContent(sourceType, format, data); + } + + /** + * Reads DocumentContent from stream input. + */ + private DocumentContent readDocumentContent(StreamInput in) throws IOException { + SourceType sourceType = SourceType.valueOf(in.readString()); + String format = in.readString(); + String data = in.readString(); + + return new DocumentContent(sourceType, format, data); + } + + /** + * Reads a list of Messages from stream input. + */ + private List readMessagesList(StreamInput in) throws IOException { + int size = in.readInt(); + List messages = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + Message message = readMessage(in); + messages.add(message); + } + + return messages; + } + + /** + * Reads a single Message from stream input. + */ + private Message readMessage(StreamInput in) throws IOException { + String role = in.readString(); + List content = readContentBlocksList(in); + + Message message = new Message(); + message.setRole(role); + message.setContent(content); + return message; + } + + /** + * Constructor for XContent parsing. + * Supports the simplified format where everything is under "input" field. 
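 * Illustrative request shapes accepted here (editor's sketch, values are placeholders):
 *   {"input": "hi does this work"}                                               -> plain text
 *   {"input": [{"type": "text", "text": "describe this"},
 *              {"type": "image", "source": {"type": "base64", "format": "png", "data": "..."}}]} -> content blocks
 *   {"input": [{"role": "user", "content": [{"type": "text", "text": "hello"}]}]} -> messages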
+ */ + public AgentInput(XContentParser parser) throws IOException { + XContentParser.Token currentToken = parser.currentToken(); + if (currentToken == XContentParser.Token.VALUE_STRING) { + // Plain text: {"input": "hi does this work"} + this.input = parser.text(); + } else if (currentToken == XContentParser.Token.START_ARRAY) { + // Array format: could be content blocks or messages + this.input = parseInputArray(parser); + } else { + throw new IllegalArgumentException("Invalid input format. Expected string or array."); + } + } + + /** + * Parses an array input and determines if it's content blocks or messages. + */ + private Object parseInputArray(XContentParser parser) throws IOException { + List items = new ArrayList<>(); + + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + Object item = parseArrayItem(parser); + items.add(item); + } + + // Determine if this is messages or content blocks based on first item + if (!items.isEmpty()) { + Object firstItem = items.getFirst(); + if (firstItem instanceof Message) { + List messages = new ArrayList<>(); + for (Object item : items) { + messages.add((Message) item); + } + return messages; + } else if (firstItem instanceof ContentBlock) { + List contentBlocks = new ArrayList<>(); + for (Object item : items) { + contentBlocks.add((ContentBlock) item); + } + return contentBlocks; + } + } + + return items; + } + + /** + * Parses a single item from the input array. + * Determines if it's a Message or ContentBlock based on structure. + */ + private Object parseArrayItem(XContentParser parser) throws IOException { + Map itemMap = new HashMap<>(); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case "role": + // This indicates it's a Message + itemMap.put("role", parser.text()); + break; + case "content": + // Parse content array for messages + itemMap.put("content", parseContentArray(parser)); + break; + case "type": + // This indicates it's a ContentBlock + itemMap.put("type", parser.text()); + break; + case "text": + itemMap.put("text", parser.text()); + break; + case "source": + itemMap.put("source", parseSource(parser)); + break; + default: + // Store other fields as-is + itemMap.put(fieldName, parseValue(parser)); + break; + } + } + + // Determine if this is a Message or ContentBlock + if (itemMap.containsKey("role")) { + return createMessage(itemMap); + } + + if (itemMap.containsKey("type")) { + return createContentBlock(itemMap); + } + + throw new IllegalArgumentException("Invalid item format. Must have 'role' (for messages) or 'type' (for content blocks)."); + } + + /** + * Parses content array for messages. + */ + @SuppressWarnings("unchecked") + private List parseContentArray(XContentParser parser) throws IOException { + List contentBlocks = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + ContentBlock contentBlock = (ContentBlock) parseArrayItem(parser); + contentBlocks.add(contentBlock); + } + + return contentBlocks; + } + + /** + * Parses source object for media content. 
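 * For example (illustrative), {"type": "base64", "format": "jpeg", "data": "<base64>"} or
 * {"type": "url", "format": "pdf", "data": "<url>"}; the keys are returned as-is and
 * interpreted later by createImageContent, createVideoContent and createDocumentContent.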
+ */ + private Map parseSource(XContentParser parser) throws IOException { + Map source = new HashMap<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + source.put(fieldName, parseValue(parser)); + } + + return source; + } + + /** + * Parses a generic value from the parser. + */ + private Object parseValue(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + + switch (token) { + case VALUE_STRING: + return parser.text(); + case VALUE_NUMBER: + return parser.numberValue(); + case VALUE_BOOLEAN: + return parser.booleanValue(); + case VALUE_NULL: + return null; + case START_OBJECT: + Map map = new HashMap<>(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + map.put(fieldName, parseValue(parser)); + } + return map; + case START_ARRAY: + List list = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + list.add(parseValue(parser)); + } + return list; + default: + throw new IllegalArgumentException("Unexpected token: " + token); + } + } + + /** + * Creates a Message object from parsed data. + */ + @SuppressWarnings("unchecked") + private Message createMessage(Map itemMap) { + String role = (String) itemMap.get("role"); + List content = (List) itemMap.get("content"); + + Message message = new Message(); + message.setRole(role); + message.setContent(content); + return message; + } + + /** + * Creates a ContentBlock object from parsed data. + */ + @SuppressWarnings("unchecked") + private ContentBlock createContentBlock(Map itemMap) { + String type = (String) itemMap.get("type"); + ContentType contentType = ContentType.valueOf(type.toUpperCase()); + + ContentBlock contentBlock = new ContentBlock(); + contentBlock.setType(contentType); + + switch (contentType) { + case TEXT: + contentBlock.setText((String) itemMap.get("text")); + break; + case IMAGE: + Map source = (Map) itemMap.get("source"); + ImageContent imageContent = createImageContent(source); + contentBlock.setImage(imageContent); + break; + case VIDEO: + Map videoSource = (Map) itemMap.get("source"); + VideoContent videoContent = createVideoContent(videoSource); + contentBlock.setVideo(videoContent); + break; + case DOCUMENT: + Map docSource = (Map) itemMap.get("source"); + DocumentContent documentContent = createDocumentContent(docSource); + contentBlock.setDocument(documentContent); + break; + } + + return contentBlock; + } + + /** + * Creates ImageContent from source data. + */ + private ImageContent createImageContent(Map source) { + String format = (String) source.get("format"); + String type = (String) source.get("type"); + String data = (String) source.get("data"); + + SourceType sourceType = SourceType.valueOf(type.toUpperCase()); + + ImageContent imageContent = new ImageContent(); + imageContent.setFormat(format); + imageContent.setType(sourceType); + imageContent.setData(data); + return imageContent; + } + + /** + * Creates VideoContent from source data. 
+ */ + private VideoContent createVideoContent(Map source) { + String format = (String) source.get("format"); + String type = (String) source.get("type"); + String data = (String) source.get("data"); + + SourceType sourceType = SourceType.valueOf(type.toUpperCase()); + + VideoContent videoContent = new VideoContent(); + videoContent.setFormat(format); + videoContent.setType(sourceType); + videoContent.setData(data); + return videoContent; + } + + /** + * Creates DocumentContent from source data. + */ + private DocumentContent createDocumentContent(Map source) { + String format = (String) source.get("format"); + String type = (String) source.get("type"); + String data = (String) source.get("data"); + + SourceType sourceType = SourceType.valueOf(type.toUpperCase()); + + DocumentContent documentContent = new DocumentContent(); + documentContent.setFormat(format); + documentContent.setType(sourceType); + documentContent.setData(data); + return documentContent; + } + + @Override + public void writeTo(StreamOutput out) throws IllegalArgumentException, IOException { + InputType inputType = getInputType(); + out.writeString(inputType.name()); + + switch (inputType) { + case TEXT: + out.writeString((String) input); + break; + case CONTENT_BLOCKS: + @SuppressWarnings("unchecked") + List contentBlocks = (List) input; + writeContentBlocksList(out, contentBlocks); + break; + case MESSAGES: + @SuppressWarnings("unchecked") + List messages = (List) input; + writeMessagesList(out, messages); + break; + default: + throw new IllegalArgumentException("Unsupported input type: " + inputType); + } + } + + /** + * Writes a list of ContentBlocks to stream output. + */ + private void writeContentBlocksList(StreamOutput out, List contentBlocks) throws IOException { + out.writeInt(contentBlocks.size()); + for (ContentBlock block : contentBlocks) { + writeContentBlock(out, block); + } + } + + /** + * Writes a single ContentBlock to stream output. + */ + private void writeContentBlock(StreamOutput out, ContentBlock block) throws IOException { + out.writeString(block.getType().name()); + + switch (block.getType()) { + case TEXT: + out.writeString(block.getText()); + break; + case IMAGE: + writeImageContent(out, block.getImage()); + break; + case VIDEO: + writeVideoContent(out, block.getVideo()); + break; + case DOCUMENT: + writeDocumentContent(out, block.getDocument()); + break; + } + } + + /** + * Writes ImageContent to stream output. + */ + private void writeImageContent(StreamOutput out, ImageContent image) throws IOException { + out.writeString(image.getType().name()); + out.writeString(image.getFormat()); + out.writeString(image.getData()); + } + + /** + * Writes VideoContent to stream output. + */ + private void writeVideoContent(StreamOutput out, VideoContent video) throws IOException { + out.writeString(video.getType().name()); + out.writeString(video.getFormat()); + out.writeString(video.getData()); + } + + /** + * Writes DocumentContent to stream output. + */ + private void writeDocumentContent(StreamOutput out, DocumentContent document) throws IOException { + out.writeString(document.getType().name()); + out.writeString(document.getFormat()); + out.writeString(document.getData()); + } + + /** + * Writes a list of Messages to stream output. + */ + private void writeMessagesList(StreamOutput out, List messages) throws IOException { + out.writeInt(messages.size()); + + for (Message message : messages) { + writeMessage(out, message); + } + } + + /** + * Writes a single Message to stream output. 
+ */ + private void writeMessage(StreamOutput out, Message message) throws IOException { + out.writeString(message.getRole()); + writeContentBlocksList(out, message.getContent()); + } + + /** + * Determines the type of input based on the input object. + * @return InputType enum value indicating the format of the input + */ + public InputType getInputType() throws IllegalArgumentException { + if (input instanceof String) { + return InputType.TEXT; + } + + if (input instanceof List list) { + if (!list.isEmpty()) { + Object firstElement = list.getFirst(); + if (firstElement instanceof ContentBlock) { + return InputType.CONTENT_BLOCKS; + } + + if (firstElement instanceof Message) { + return InputType.MESSAGES; + } + } + } + + throw new IllegalArgumentException("Input type not supported: " + input); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/AgentInputProcessor.java b/common/src/main/java/org/opensearch/ml/common/agent/AgentInputProcessor.java new file mode 100644 index 0000000000..033a8d0271 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/AgentInputProcessor.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.util.List; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for validating standardized agent input formats. + * The AgentInput itself is already standardized - this validator just validates it + * and ensures it's ready to be passed to ModelProviders for conversion to their + * specific request body formats. + */ +@Log4j2 +public class AgentInputProcessor { + + // Private constructor to prevent instantiation + private AgentInputProcessor() { + throw new UnsupportedOperationException("Utility class cannot be instantiated"); + } + + /** + * Validates the standardized AgentInput. + * The AgentInput is passed through after validation - ModelProviders will + * handle the conversion to their specific request body parameters. + * + * @param agentInput the standardized agent input + * @throws IllegalArgumentException if input is invalid + */ + public static void validateInput(AgentInput agentInput) { + if (agentInput == null || agentInput.getInput() == null) { + throw new IllegalArgumentException("AgentInput and its input field cannot be null"); + } + + InputType type = agentInput.getInputType(); + + switch (type) { + case TEXT: + validateTextInput((String) agentInput.getInput()); + break; + case CONTENT_BLOCKS: + @SuppressWarnings("unchecked") + List blocks = (List) agentInput.getInput(); + validateContentBlocks(blocks); + break; + case MESSAGES: + @SuppressWarnings("unchecked") + List messages = (List) agentInput.getInput(); + validateMessages(messages); + break; + default: + throw new IllegalArgumentException("Unsupported input type: " + type); + } + } + + /** + * Validates simple text input. + * + * @param text the text input + * @throws IllegalArgumentException if text is invalid + */ + private static void validateTextInput(String text) { + if (text == null || text.trim().isEmpty()) { + throw new IllegalArgumentException("Text input cannot be null or empty"); + } + } + + /** + * Validates multi-modal content blocks. 
+ * + * @param blocks the list of content blocks + * @throws IllegalArgumentException if content blocks are invalid + */ + private static void validateContentBlocks(List blocks) { + if (blocks == null || blocks.isEmpty()) { + throw new IllegalArgumentException("Content blocks cannot be null or empty"); + } + + for (ContentBlock block : blocks) { + if (block.getType() == null) { + throw new IllegalArgumentException("Content block type cannot be null"); + } + + switch (block.getType()) { + case TEXT: + if (block.getText() == null || block.getText().trim().isEmpty()) { + throw new IllegalArgumentException("Text content block cannot have null or empty text"); + } + break; + case IMAGE: + if (block.getImage() == null) { + throw new IllegalArgumentException("Image content block must have image data"); + } + break; + case DOCUMENT: + if (block.getDocument() == null) { + throw new IllegalArgumentException("Document content block must have document data"); + } + break; + case VIDEO: + if (block.getVideo() == null) { + throw new IllegalArgumentException("Video content block must have video data"); + } + break; + default: + throw new IllegalArgumentException("Unsupported content block type: " + block.getType()); + } + } + } + + /** + * Validates message-based conversation input. + * + * @param messages the list of messages + * @throws IllegalArgumentException if messages are invalid + */ + private static void validateMessages(List messages) { + if (messages == null || messages.isEmpty()) { + throw new IllegalArgumentException("Messages cannot be null or empty"); + } + + for (Message message : messages) { + if (message.getRole() == null || message.getRole().trim().isEmpty()) { + throw new IllegalArgumentException("Message role cannot be null or empty"); + } + + if (message.getContent() == null || message.getContent().isEmpty()) { + throw new IllegalArgumentException("Message content cannot be null or empty"); + } + + // Validate each content block in the message + validateContentBlocks(message.getContent()); + } + } + + /** + * Extracts question text from AgentInput for prompt template usage. + * This provides the text that will be used in prompt templates that reference $parameters.question. + */ + public static String extractQuestionText(AgentInput agentInput) { + validateInput(agentInput); + return switch (agentInput.getInputType()) { + case TEXT -> (String) agentInput.getInput(); + case CONTENT_BLOCKS -> { + // For content blocks, extract and combine text content + @SuppressWarnings("unchecked") + List blocks = (List) agentInput.getInput(); + yield extractTextFromContentBlocks(blocks); + } + case MESSAGES -> { + // For messages, extract the last user message text + @SuppressWarnings("unchecked") + List messages = (List) agentInput.getInput(); + yield extractTextFromMessages(messages); + } + default -> throw new IllegalArgumentException("Unsupported input type: " + agentInput.getInputType()); + }; + } + + /** + * Extracts text content from content blocks for human-readable display. 
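 * For example (illustrative), a text block "What is in this picture?" followed by an image
 * block yields "What is in this picture?\n"; image, video and document blocks contribute nothing.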
+ * Ignores non text blocks + * @throws IllegalArgumentException if content blocks are invalid[ + */ + private static String extractTextFromContentBlocks(List blocks) { + if (blocks == null || blocks.isEmpty()) { + throw new IllegalArgumentException("Content blocks cannot be null or empty"); + } + + StringBuilder textBuilder = new StringBuilder(); + for (ContentBlock block : blocks) { + if (block.getType() == ContentType.TEXT) { + String text = block.getText(); + if (text != null && !text.trim().isEmpty()) { + textBuilder.append(text.trim()); + textBuilder.append("\n"); + } + } + } + + return textBuilder.toString(); + } + + /** + * Extracts text content from last message. + */ + private static String extractTextFromMessages(List messages) { + if (messages == null || messages.isEmpty()) { + throw new IllegalArgumentException("Messages cannot be null or empty"); + } + + Message message = messages.getLast(); + return extractTextFromContentBlocks(message.getContent()); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/AgentModelService.java b/common/src/main/java/org/opensearch/ml/common/agent/AgentModelService.java new file mode 100644 index 0000000000..907e03c0ed --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/AgentModelService.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; + +import lombok.extern.log4j.Log4j2; + +/** + * Service class for handling model creation during agent registration + */ +@Log4j2 +public class AgentModelService { + + /** + * Creates a model input from the agent model specification + * @param modelSpec the model specification from agent registration + * @return MLRegisterModelInput ready for model registration + * @throws IllegalArgumentException if model provider is not supported + */ + public static MLRegisterModelInput createModelFromSpec(MLAgentModelSpec modelSpec) { + validateModelSpec(modelSpec); + ModelProvider provider = ModelProviderFactory.getProvider(modelSpec.getModelProvider()); + + Connector connector = provider.createConnector(modelSpec.getModelId(), modelSpec.getCredential(), modelSpec.getModelParameters()); + + return provider.createModelInput(modelSpec.getModelId(), connector, modelSpec.getModelParameters()); + } + + /** + * Infers the LLM interface from model provider for function calling + * @param modelProvider the model provider string + * @return the corresponding LLM interface string, or null if not supported + */ + public static String inferLLMInterface(String modelProvider) { + if (modelProvider == null) { + return null; + } + + try { + ModelProvider provider = ModelProviderFactory.getProvider(modelProvider); + return provider.getLLMInterface(); + } catch (Exception e) { + log.error("Failed to infer LLM interface", e); + return null; + } + } + + /** + * Validates that the model specification is complete and valid + * @param modelSpec the model specification to validate + * @throws IllegalArgumentException if validation fails + */ + private static void validateModelSpec(MLAgentModelSpec modelSpec) { + if (modelSpec == null) { + throw new IllegalArgumentException("Model specification not found"); + } + + if (modelSpec.getModelId() == null || modelSpec.getModelId().trim().isEmpty()) { + throw new IllegalArgumentException("model_id cannot be null or empty"); + } + + if 
(modelSpec.getModelProvider() == null || modelSpec.getModelProvider().trim().isEmpty()) { + throw new IllegalArgumentException("model_provider cannot be null or empty"); + } + + // Validate that the provider type is supported + try { + ModelProviderType.from(modelSpec.getModelProvider()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Unsupported model provider: " + modelSpec.getModelProvider()); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/BedrockConverseModelProvider.java b/common/src/main/java/org/opensearch/ml/common/agent/BedrockConverseModelProvider.java new file mode 100644 index 0000000000..308f91af92 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/BedrockConverseModelProvider.java @@ -0,0 +1,279 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.text.StringEscapeUtils; +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.utils.ToolUtils; + +/** + * Model provider for Bedrock Converse API. + * + * This provider uses template-based parameter substitution with StringSubstitutor + * to create the request body. Different input types (text, content blocks, messages) + * use different body templates that are parameterized and filled in at runtime. + * + * The main request body template uses ${parameters.body} which gets populated + * with the appropriate message structure based on the input type. + * + * Template parameters are uniquely named to avoid conflicts: + * - Text input: ${parameters.user_text} + * - Content blocks: ${parameters.content_array} + * - Messages: ${parameters.messages_array} + * - Content types use prefixed parameters: ${parameters.content_text}, ${parameters.image_format}, etc. + * - Source types are dynamically mapped: BASE64 → "bytes", URL → "s3Location" + * + * All parameters consistently use the ${parameters.} prefix for uniformity. 
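 * Illustrative walk-through (editor's sketch): for plain text input "hello", mapTextInput sets
 *   parameters.body = {"role":"user","content":[{"text":"hello"}]}
 * and the request body template expands to
 *   {"system": [{"text": "<system_prompt>"}], "messages": [{"role":"user","content":[{"text":"hello"}]}]}
 * with _chat_history, _interactions and tool_configs resolving to empty strings via the :- defaults
 * when they are not supplied at runtime.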
+ */ +// todo: refactor the processing so providers have to only provide the constants +public class BedrockConverseModelProvider extends ModelProvider { + + private static final String DEFAULT_REGION = "us-east-1"; + + private static final String REQUEST_BODY_TEMPLATE = "{\"system\": [{\"text\": \"${parameters.system_prompt}\"}], " + + "\"messages\": [${parameters._chat_history:-}${parameters.body}${parameters._interactions:-}]" + + "${parameters.tool_configs:-} }"; + + // Body templates for different input types + private static final String TEXT_INPUT_BODY_TEMPLATE = "{\"role\":\"user\",\"content\":[{\"text\":\"${parameters.user_text}\"}]}"; + + private static final String CONTENT_BLOCKS_BODY_TEMPLATE = "{\"role\":\"user\",\"content\":[${parameters.content_array}]}"; + + // Content block templates for multi-modal content + private static final String TEXT_CONTENT_TEMPLATE = "{\"text\":\"${parameters.content_text}\"}"; + + private static final String IMAGE_CONTENT_TEMPLATE = + "{\"image\":{\"format\":\"${parameters.image_format}\",\"source\":{\"${parameters.image_source_type}\":\"${parameters.image_data}\"}}}"; + + private static final String DOCUMENT_CONTENT_TEMPLATE = + "{\"document\":{\"format\":\"${parameters.doc_format}\",\"name\":\"${parameters.doc_name}\",\"source\":{\"${parameters.doc_source_type}\":\"${parameters.doc_data}\"}}}"; + + private static final String VIDEO_CONTENT_TEMPLATE = + "{\"video\":{\"format\":\"${parameters.video_format}\",\"source\":{\"${parameters.video_source_type}\":\"${parameters.video_data}\"}}}"; + + private static final String MESSAGE_TEMPLATE = "{\"role\":\"${parameters.msg_role}\",\"content\":[${parameters.msg_content_array}]}"; + + @Override + public Connector createConnector(String modelId, Map credential, Map modelParameters) { + Map parameters = new HashMap<>(); + parameters.put("region", DEFAULT_REGION); // Default region, can be overridden + parameters.put("service_name", "bedrock"); + parameters.put("model", modelId); + + // Override with any provided model parameters + if (modelParameters != null) { + parameters.putAll(modelParameters); + } + + Map headers = new HashMap<>(); + headers.put("content-type", "application/json"); + + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse") + .headers(headers) + .requestBody(REQUEST_BODY_TEMPLATE) + .build(); + + // Set agent connector to have default 3 retries + ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(); + connectorClientConfig.setMaxRetryTimes(3); + + return AwsConnector + .awsConnectorBuilder() + .name("Auto-generated Bedrock Converse connector for Agent") + .description("Auto-generated connector for Bedrock Converse API") + .version("1") + .protocol(ConnectorProtocols.AWS_SIGV4) + .parameters(parameters) + .credential(credential != null ? 
credential : new HashMap<>()) + .actions(List.of(predictAction)) + .connectorClientConfig(connectorClientConfig) + .build(); + } + + @Override + public MLRegisterModelInput createModelInput(String modelName, Connector connector, Map modelParameters) { + return MLRegisterModelInput + .builder() + .functionName(FunctionName.REMOTE) + .modelName("Auto-generated model for " + modelName) + .description("Auto-generated model for agent") + .connector(connector) + .build(); + } + + @Override + public String getLLMInterface() { + return "bedrock/converse/claude"; + } + + @Override + public Map mapTextInput(String text) { + Map parameters = new HashMap<>(); + + // Use StringSubstitutor for parameter replacement + Map templateParams = new HashMap<>(); + templateParams.put("user_text", StringEscapeUtils.escapeJson(text)); + + StringSubstitutor substitutor = new StringSubstitutor(templateParams, "${parameters.", "}"); + String body = substitutor.replace(TEXT_INPUT_BODY_TEMPLATE); + parameters.put("body", body); + + return parameters; + } + + @Override + public Map mapContentBlocks(List contentBlocks) { + Map parameters = new HashMap<>(); + + // Use StringSubstitutor for parameter replacement + String contentArray = buildContentArrayFromBlocks(contentBlocks); + Map templateParams = new HashMap<>(); + templateParams.put("content_array", contentArray); + + StringSubstitutor substitutor = new StringSubstitutor(templateParams, "${parameters.", "}"); + String body = substitutor.replace(CONTENT_BLOCKS_BODY_TEMPLATE); + parameters.put("body", body); + + return parameters; + } + + @Override + public Map mapMessages(List messages) { + Map parameters = new HashMap<>(); + String messagesString = buildMessagesArray(messages); + parameters.put("body", messagesString); + // todo: Merge function calling code into this class + // body is added to no_escape_params as the json constructed is a sequence of objects and not a valid json + // it becomes valid as REQUEST_BODY_TEMPLATE wraps this in an array + parameters.put(ToolUtils.NO_ESCAPE_PARAMS, "_chat_history,_tools,_interactions,tool_configs,body"); + return parameters; + } + + /** + * Builds content array from content blocks using templates for Bedrock Converse API. + * Supports text, image, document, and video content types. 
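 * For example (illustrative), a text block plus a base64 PNG image block produce
 *   {"text":"describe this"},{"image":{"format":"png","source":{"bytes":"<base64>"}}}
 * which CONTENT_BLOCKS_BODY_TEMPLATE wraps into the user message's "content" array.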
+ */ + private String buildContentArrayFromBlocks(List blocks) { + if (blocks == null || blocks.isEmpty()) { + return ""; + } + + StringBuilder contentArray = new StringBuilder(); + boolean first = true; + for (ContentBlock block : blocks) { + if (!first) { + contentArray.append(","); + } + first = false; + + switch (block.getType()) { + case TEXT: + Map textParams = new HashMap<>(); + textParams.put("content_text", StringEscapeUtils.escapeJson(block.getText())); + StringSubstitutor textSubstitutor = new StringSubstitutor(textParams, "${parameters.", "}"); + contentArray.append(textSubstitutor.replace(TEXT_CONTENT_TEMPLATE)); + break; + case IMAGE: + ImageContent image = block.getImage(); + Map imageParams = new HashMap<>(); + imageParams.put("image_format", image.getFormat()); + imageParams.put("image_data", image.getData()); + // Map SourceType to Bedrock Converse API source type + String imageSourceType = mapSourceTypeToBedrock(image.getType()); + imageParams.put("image_source_type", imageSourceType); + StringSubstitutor imageSubstitutor = new StringSubstitutor(imageParams, "${parameters.", "}"); + contentArray.append(imageSubstitutor.replace(IMAGE_CONTENT_TEMPLATE)); + break; + case DOCUMENT: + DocumentContent document = block.getDocument(); + Map docParams = new HashMap<>(); + docParams.put("doc_format", document.getFormat()); + docParams.put("doc_name", "document"); + docParams.put("doc_data", document.getData()); + // Map SourceType to Bedrock Converse API source type + String docSourceType = mapSourceTypeToBedrock(document.getType()); + docParams.put("doc_source_type", docSourceType); + StringSubstitutor docSubstitutor = new StringSubstitutor(docParams, "${parameters.", "}"); + contentArray.append(docSubstitutor.replace(DOCUMENT_CONTENT_TEMPLATE)); + break; + case VIDEO: + VideoContent video = block.getVideo(); + Map videoParams = new HashMap<>(); + videoParams.put("video_format", video.getFormat()); + videoParams.put("video_data", video.getData()); + // Map SourceType to Bedrock Converse API source type + String videoSourceType = mapSourceTypeToBedrock(video.getType()); + videoParams.put("video_source_type", videoSourceType); + StringSubstitutor videoSubstitutor = new StringSubstitutor(videoParams, "${parameters.", "}"); + contentArray.append(videoSubstitutor.replace(VIDEO_CONTENT_TEMPLATE)); + break; + default: + // Skip unsupported content types + break; + } + } + + return contentArray.toString(); + } + + /** + * Builds messages array using templates for Bedrock Converse API. + * Converts messages to conversation history format, excluding the last user message + * which becomes the current input. + */ + private String buildMessagesArray(List messages) { + if (messages == null || messages.isEmpty()) { + return ""; + } + + StringBuilder messagesArray = new StringBuilder(); + boolean first = true; + for (Message message : messages) { + if (!first) { + messagesArray.append(","); + } + first = false; + + String contentArray = buildContentArrayFromBlocks(message.getContent()); + Map msgParams = new HashMap<>(); + msgParams.put("msg_role", message.getRole()); + msgParams.put("msg_content_array", contentArray); + StringSubstitutor msgSubstitutor = new StringSubstitutor(msgParams, "${parameters.", "}"); + messagesArray.append(msgSubstitutor.replace(MESSAGE_TEMPLATE)); + } + + return messagesArray.toString(); + } + + /** + * Maps SourceType to Bedrock Converse API source field names. 
+ * + * @param sourceType the source type from content + * @return the corresponding Bedrock API source field name + */ + private String mapSourceTypeToBedrock(SourceType sourceType) { + if (sourceType == SourceType.URL) { + return "s3Location"; // Bedrock Converse API uses s3Location for URL-based content + } + + return "bytes"; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ContentBlock.java b/common/src/main/java/org/opensearch/ml/common/agent/ContentBlock.java new file mode 100644 index 0000000000..ac0ff30801 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ContentBlock.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents a content block that can contain different types of content (text, image, video, document). + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ContentBlock implements Writeable { + private ContentType type; + private String text; // for text content + private ImageContent image; // for image content + private VideoContent video; // for video content + private DocumentContent document; // for document content + + /** + * Constructor for text content block. + */ + public ContentBlock(String text) { + this.type = ContentType.TEXT; + this.text = text; + } + + /** + * Constructor for image content block. + */ + public ContentBlock(ImageContent image) { + this.type = ContentType.IMAGE; + this.image = image; + } + + /** + * Constructor for video content block. + */ + public ContentBlock(VideoContent video) { + this.type = ContentType.VIDEO; + this.video = video; + } + + /** + * Constructor for document content block. + */ + public ContentBlock(DocumentContent document) { + this.type = ContentType.DOCUMENT; + this.document = document; + } + + // TODO: Add stream and XContent constructors when content classes support them + // For now, we'll implement a basic version for the POC + + @Override + public void writeTo(StreamOutput out) throws IOException { + // Basic implementation for POC - only support text content for now + out.writeString(type != null ? type.name() : ContentType.TEXT.name()); + out.writeOptionalString(text); + // TODO: Add support for other content types when their classes support serialization + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ContentType.java b/common/src/main/java/org/opensearch/ml/common/agent/ContentType.java new file mode 100644 index 0000000000..52071bf86b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ContentType.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +/** + * Enum representing the different types of content that can be included in a content block. 
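+ * <p>For example, {@code new ContentBlock("hi")} yields a {@code TEXT} block, while
+ * {@code new ContentBlock(new ImageContent(SourceType.BASE64, "png", "<base64-data>"))} yields an {@code IMAGE} block (illustrative values).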
+ */ +public enum ContentType { + TEXT, + IMAGE, + VIDEO, + DOCUMENT +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/DocumentContent.java b/common/src/main/java/org/opensearch/ml/common/agent/DocumentContent.java new file mode 100644 index 0000000000..7e50ead141 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/DocumentContent.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents document content with type, format, and data fields. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class DocumentContent { + private SourceType type; + private String format; // "pdf", "docx", "txt", etc. + private String data; // URL or base64 data +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ImageContent.java b/common/src/main/java/org/opensearch/ml/common/agent/ImageContent.java new file mode 100644 index 0000000000..b682c44236 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ImageContent.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents image content with type, format, and data fields. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ImageContent { + private SourceType type; + private String format; // "jpeg", "png", "gif", "webp" + private String data; // URL or base64 data +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/InputType.java b/common/src/main/java/org/opensearch/ml/common/agent/InputType.java new file mode 100644 index 0000000000..9e5ee4f382 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/InputType.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +/** + * Enum representing the different types of standardized agent input formats. + */ +public enum InputType { + TEXT, + CONTENT_BLOCKS, + MESSAGES, + UNKNOWN +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/InputValidator.java b/common/src/main/java/org/opensearch/ml/common/agent/InputValidator.java new file mode 100644 index 0000000000..31dca3df30 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/InputValidator.java @@ -0,0 +1,288 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.util.List; + +import org.opensearch.ml.common.exception.MLValidationException; + +/** + * Validates agent input formats and content to ensure they meet the required structure + * and contain valid data before processing. + */ +// ToDo: this validation is too strict, take a look at the validation logic and fix it +public class InputValidator { + + /** + * Validates an AgentInput object based on its detected input type. 
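+     * <p>Illustrative usage (a sketch, assuming {@code new AgentInput(String)} yields a TEXT-typed input):
+     * <pre>{@code
+     * new InputValidator().validateAgentInput(new AgentInput("What is the weather in Seattle?"));
+     * }</pre>
+     * Blank text, an empty content-block list, or a message without a role fails with {@code MLValidationException}.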
+ * + * @param input the AgentInput to validate + * @throws MLValidationException if validation fails + */ + public void validateAgentInput(AgentInput input) throws MLValidationException { + if (input == null || input.getInput() == null) { + throw new MLValidationException("Input cannot be null"); + } + + InputType type = input.getInputType(); + switch (type) { + case CONTENT_BLOCKS: + validateContentBlocks((List) input.getInput()); + break; + case MESSAGES: + validateMessages((List) input.getInput()); + break; + case TEXT: + validateText((String) input.getInput()); + break; + case UNKNOWN: + default: + throw new MLValidationException("Invalid input format. Expected string, array of content blocks, or array of messages"); + } + } + + /** + * Validates an array of content blocks. + * + * @param blocks the content blocks to validate + * @throws MLValidationException if validation fails + */ + public void validateContentBlocks(List blocks) throws MLValidationException { + if (blocks == null || blocks.isEmpty()) { + throw new MLValidationException("Content blocks cannot be null or empty"); + } + + int index = 0; + for (ContentBlock block : blocks) { + try { + validateContentBlock(block); + } catch (MLValidationException e) { + throw new MLValidationException("Content block at index " + index + " is invalid: " + e.getMessage()); + } + index++; + } + } + + /** + * Validates an array of messages. + * + * @param messages the messages to validate + * @throws MLValidationException if validation fails + */ + public void validateMessages(List messages) throws MLValidationException { + if (messages == null || messages.isEmpty()) { + throw new MLValidationException("Messages cannot be null or empty"); + } + + int index = 0; + for (Message message : messages) { + try { + if (message == null) { + throw new MLValidationException("Message cannot be null"); + } + + if (message.getRole() == null || message.getRole().trim().isEmpty()) { + throw new MLValidationException("Message must have a non-empty role"); + } + + if (message.getContent() == null) { + throw new MLValidationException("Message must have content"); + } + + validateContentBlocks(message.getContent()); + } catch (MLValidationException e) { + throw new MLValidationException("Message at index " + index + " is invalid: " + e.getMessage()); + } + index++; + } + } + + /** + * Validates a single content block. + * + * @param block the content block to validate + * @throws MLValidationException if validation fails + */ + public void validateContentBlock(ContentBlock block) throws MLValidationException { + if (block == null) { + throw new MLValidationException("Content block cannot be null"); + } + + if (block.getType() == null) { + throw new MLValidationException("Content block must have a type"); + } + + switch (block.getType()) { + case TEXT: + if (block.getText() == null || block.getText().trim().isEmpty()) { + throw new MLValidationException("Text content block must have non-empty text field"); + } + break; + case IMAGE: + validateImageContent(block.getImage()); + break; + case VIDEO: + validateVideoContent(block.getVideo()); + break; + case DOCUMENT: + validateDocumentContent(block.getDocument()); + break; + default: + throw new MLValidationException("Unsupported content block type: " + block.getType()); + } + } + + /** + * Validates image content. 
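+     * <p>For example (illustrative values), this passes validation:
+     * <pre>{@code
+     * validateImageContent(new ImageContent(SourceType.URL, "png", "https://example.com/diagram.png"));
+     * }</pre>
+     * while a BASE64 payload containing spaces, or an unlisted format such as "heic", is rejected.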
+ * + * @param imageContent the image content to validate + * @throws MLValidationException if validation fails + */ + public void validateImageContent(ImageContent imageContent) throws MLValidationException { + if (imageContent == null) { + throw new MLValidationException("Image content cannot be null for image content block"); + } + + if (imageContent.getType() == null) { + throw new MLValidationException("Image content must have a source type (URL or BASE64)"); + } + + if (imageContent.getFormat() == null || imageContent.getFormat().trim().isEmpty()) { + throw new MLValidationException("Image content must specify a format (e.g., jpeg, png, gif, webp)"); + } + + if (imageContent.getData() == null || imageContent.getData().trim().isEmpty()) { + throw new MLValidationException("Image content must have data (URL or base64 encoded data)"); + } + + // Validate format is reasonable for images + String format = imageContent.getFormat().toLowerCase(); + if (!format.matches("^(jpeg|jpg|png|gif|webp|bmp|tiff|svg)$")) { + throw new MLValidationException( + "Unsupported image format: " + imageContent.getFormat() + ". Supported formats: jpeg, jpg, png, gif, webp, bmp, tiff, svg" + ); + } + + // Basic validation for URL vs base64 + if (imageContent.getType() == SourceType.URL) { + if (!imageContent.getData().matches("^https?://.*")) { + throw new MLValidationException("URL source type requires data to be a valid HTTP/HTTPS URL"); + } + } else if (imageContent.getType() == SourceType.BASE64) { + // Basic base64 validation - should not contain spaces and have reasonable length + String data = imageContent.getData().trim(); + if (data.contains(" ") || data.length() < 4) { + throw new MLValidationException("BASE64 source type requires valid base64 encoded data"); + } + } + } + + /** + * Validates video content. + * + * @param videoContent the video content to validate + * @throws MLValidationException if validation fails + */ + public void validateVideoContent(VideoContent videoContent) throws MLValidationException { + if (videoContent == null) { + throw new MLValidationException("Video content cannot be null for video content block"); + } + + if (videoContent.getType() == null) { + throw new MLValidationException("Video content must have a source type (URL or BASE64)"); + } + + if (videoContent.getFormat() == null || videoContent.getFormat().trim().isEmpty()) { + throw new MLValidationException("Video content must specify a format (e.g., mp4, mov, avi)"); + } + + if (videoContent.getData() == null || videoContent.getData().trim().isEmpty()) { + throw new MLValidationException("Video content must have data (URL or base64 encoded data)"); + } + + // Validate format is reasonable for videos + String format = videoContent.getFormat().toLowerCase(); + if (!format.matches("^(mp4|mov|avi|mkv|wmv|flv|webm|m4v|3gp)$")) { + throw new MLValidationException( + "Unsupported video format: " + + videoContent.getFormat() + + ". 
Supported formats: mp4, mov, avi, mkv, wmv, flv, webm, m4v, 3gp" + ); + } + + // Basic validation for URL vs base64 + if (videoContent.getType() == SourceType.URL) { + if (!videoContent.getData().matches("^https?://.*")) { + throw new MLValidationException("URL source type requires data to be a valid HTTP/HTTPS URL"); + } + } else if (videoContent.getType() == SourceType.BASE64) { + // Basic base64 validation - should not contain spaces and have reasonable length + String data = videoContent.getData().trim(); + if (data.contains(" ") || data.length() < 4) { + throw new MLValidationException("BASE64 source type requires valid base64 encoded data"); + } + } + } + + /** + * Validates document content. + * + * @param documentContent the document content to validate + * @throws MLValidationException if validation fails + */ + public void validateDocumentContent(DocumentContent documentContent) throws MLValidationException { + if (documentContent == null) { + throw new MLValidationException("Document content cannot be null for document content block"); + } + + if (documentContent.getType() == null) { + throw new MLValidationException("Document content must have a source type (URL or BASE64)"); + } + + if (documentContent.getFormat() == null || documentContent.getFormat().trim().isEmpty()) { + throw new MLValidationException("Document content must specify a format (e.g., pdf, docx, txt)"); + } + + if (documentContent.getData() == null || documentContent.getData().trim().isEmpty()) { + throw new MLValidationException("Document content must have data (URL or base64 encoded data)"); + } + + // Validate format is reasonable for documents + String format = documentContent.getFormat().toLowerCase(); + if (!format.matches("^(pdf|docx|doc|txt|rtf|odt|html|xml|csv|xlsx|xls|pptx|ppt)$")) { + throw new MLValidationException( + "Unsupported document format: " + + documentContent.getFormat() + + ". Supported formats: pdf, docx, doc, txt, rtf, odt, html, xml, csv, xlsx, xls, pptx, ppt" + ); + } + + // Basic validation for URL vs base64 + if (documentContent.getType() == SourceType.URL) { + if (!documentContent.getData().matches("^https?://.*")) { + throw new MLValidationException("URL source type requires data to be a valid HTTP/HTTPS URL"); + } + } else if (documentContent.getType() == SourceType.BASE64) { + // Basic base64 validation - should not contain spaces and have reasonable length + String data = documentContent.getData().trim(); + if (data.contains(" ") || data.length() < 4) { + throw new MLValidationException("BASE64 source type requires valid base64 encoded data"); + } + } + } + + /** + * Validates plain text input. 
+ * + * @param text the text to validate + * @throws MLValidationException if validation fails + */ + private void validateText(String text) throws MLValidationException { + if (text == null || text.trim().isEmpty()) { + throw new MLValidationException("Text input cannot be null or empty"); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 73dcd3a7a8..1e2f3509f8 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -43,6 +43,7 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String AGENT_TYPE_FIELD = "type"; public static final String DESCRIPTION_FIELD = "description"; public static final String LLM_FIELD = "llm"; + public static final String MODEL_FIELD = "model"; public static final String TOOLS_FIELD = "tools"; public static final String PARAMETERS_FIELD = "parameters"; public static final String MEMORY_FIELD = "memory"; @@ -63,6 +64,7 @@ public class MLAgent implements ToXContentObject, Writeable { private String type; private String description; private LLMSpec llm; + private MLAgentModelSpec model; private List tools; private Map parameters; private MLMemorySpec memory; @@ -79,6 +81,7 @@ public MLAgent( String type, String description, LLMSpec llm, + MLAgentModelSpec model, List tools, Map parameters, MLMemorySpec memory, @@ -92,6 +95,7 @@ public MLAgent( this.type = type; this.description = description; this.llm = llm; + this.model = model; this.tools = tools; this.parameters = parameters; this.memory = memory; @@ -104,6 +108,24 @@ public MLAgent( validate(); } + // Backward compatible constructor for existing tests + public MLAgent( + String name, + String type, + String description, + LLMSpec llm, + List tools, + Map parameters, + MLMemorySpec memory, + Instant createdTime, + Instant lastUpdateTime, + String appType, + Boolean isHidden, + String tenantId + ) { + this(name, type, description, llm, null, tools, parameters, memory, createdTime, lastUpdateTime, appType, isHidden, tenantId); + } + private void validate() { if (name == null) { throw new IllegalArgumentException("Agent name can't be null"); @@ -151,6 +173,9 @@ public MLAgent(StreamInput input) throws IOException { if (input.readBoolean()) { llm = new LLMSpec(input); } + if (input.readBoolean()) { + model = new MLAgentModelSpec(input); + } if (input.readBoolean()) { tools = new ArrayList<>(); int size = input.readInt(); @@ -186,6 +211,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (model != null) { + out.writeBoolean(true); + model.writeTo(out); + } else { + out.writeBoolean(false); + } if (tools != null && !tools.isEmpty()) { out.writeBoolean(true); out.writeInt(tools.size()); @@ -234,6 +265,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (llm != null) { builder.field(LLM_FIELD, llm); } + if (model != null) { + builder.field(MODEL_FIELD, model); + } if (tools != null && tools.size() > 0) { builder.field(TOOLS_FIELD, tools); } @@ -276,6 +310,7 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid String type = null; String description = null; LLMSpec llm = null; + MLAgentModelSpec model = null; List tools = null; Map parameters = null; MLMemorySpec memory = null; @@ -303,6 +338,9 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean 
parseHid case LLM_FIELD: llm = LLMSpec.parse(parser); break; + case MODEL_FIELD: + model = MLAgentModelSpec.parse(parser); + break; case TOOLS_FIELD: tools = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); @@ -344,6 +382,7 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid .type(type) .description(description) .llm(llm) + .model(model) .tools(tools) .parameters(parameters) .memory(memory) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgentModelSpec.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgentModelSpec.java new file mode 100644 index 0000000000..f98a49b9e7 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgentModelSpec.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; + +/** + * Specification for model configuration in agent registration + */ +@EqualsAndHashCode +@Getter +@Setter +public class MLAgentModelSpec implements ToXContentObject { + public static final String MODEL_ID_FIELD = "model_id"; + public static final String MODEL_PROVIDER_FIELD = "model_provider"; + public static final String CREDENTIAL_FIELD = "credential"; + public static final String MODEL_PARAMETERS_FIELD = "model_parameters"; + + private final String modelId; + private final String modelProvider; + private Map credential; + private Map modelParameters; + + @Builder(toBuilder = true) + public MLAgentModelSpec(String modelId, String modelProvider, Map credential, Map modelParameters) { + if (modelId == null) { + throw new IllegalArgumentException("model_id must be provided"); + } + + if (modelProvider == null) { + throw new IllegalArgumentException("model_provider must be provided"); + } + this.modelId = modelId; + this.modelProvider = modelProvider; + this.credential = credential; + this.modelParameters = modelParameters; + } + + public MLAgentModelSpec(StreamInput input) throws IOException { + modelId = input.readString(); + modelProvider = input.readString(); + if (input.readBoolean()) { + credential = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + if (input.readBoolean()) { + modelParameters = input.readMap(StreamInput::readString, StreamInput::readOptionalString); + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(modelProvider); + if (credential != null && !credential.isEmpty()) { + out.writeBoolean(true); + out.writeMap(credential, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + if (modelParameters != null && !modelParameters.isEmpty()) { + out.writeBoolean(true); + out.writeMap(modelParameters, StreamOutput::writeString, StreamOutput::writeOptionalString); + } else { + out.writeBoolean(false); + } + } + + 
@Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (modelProvider != null) { + builder.field(MODEL_PROVIDER_FIELD, modelProvider); + } + if (credential != null && !credential.isEmpty()) { + builder.field(CREDENTIAL_FIELD, credential); + } + if (modelParameters != null && !modelParameters.isEmpty()) { + builder.field(MODEL_PARAMETERS_FIELD, modelParameters); + } + builder.endObject(); + return builder; + } + + public static MLAgentModelSpec parse(XContentParser parser) throws IOException { + String model = null; + String modelProvider = null; + Map credential = null; + Map modelParameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + model = parser.text(); + break; + case MODEL_PROVIDER_FIELD: + modelProvider = parser.text(); + break; + case CREDENTIAL_FIELD: + credential = getParameterMap(parser.map()); + break; + case MODEL_PARAMETERS_FIELD: + modelParameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + + return new MLAgentModelSpec(model, modelProvider, credential, modelParameters); + } + + public static MLAgentModelSpec fromStream(StreamInput in) throws IOException { + return new MLAgentModelSpec(in); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/Message.java b/common/src/main/java/org/opensearch/ml/common/agent/Message.java new file mode 100644 index 0000000000..1d17ab2675 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/Message.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.util.List; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents a message with role and content fields for conversation-style input. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class Message { + private String role; // flexible - any role allowed (user, assistant, system, etc.) 
+ private List content; +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ModelProvider.java b/common/src/main/java/org/opensearch/ml/common/agent/ModelProvider.java new file mode 100644 index 0000000000..eae4871548 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ModelProvider.java @@ -0,0 +1,101 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; + +/** + * Abstract base class for model providers + */ +public abstract class ModelProvider { + + /** + * Creates a connector for this model provider + * @param modelName the model name (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0") + * @param credential credential map for the connector + * @param modelParameters additional model parameters + * @return configured Connector + */ + public abstract Connector createConnector(String modelName, Map credential, Map modelParameters); + + /** + * Creates MLRegisterModelInput for this model provider + * @param modelName the model name + * @param connector the connector to use + * @param modelParameters additional model parameters + * @return configured MLRegisterModelInput + */ + public abstract MLRegisterModelInput createModelInput(String modelName, Connector connector, Map modelParameters); + + /** + * Gets the LLM interface for function calling + * @return the LLM interface string, or null if not supported + */ + public abstract String getLLMInterface(); + + // Input mapping methods for converting standardized AgentInput to provider-specific parameters + + /** + * Maps simple text input to provider-specific request body parameters. + * Each provider implements this to convert text to their specific format. + * + * @param text the text input + * @return Map of parameters for the provider's request body template + */ + public abstract Map mapTextInput(String text); + + /** + * Maps multi-modal content blocks to provider-specific request body parameters. + * Each provider implements this to convert content blocks to their specific format. + * + * @param contentBlocks the list of content blocks + * @return Map of parameters for the provider's request body template + */ + public abstract Map mapContentBlocks(List contentBlocks); + + /** + * Maps message-based conversation to provider-specific request body parameters. + * Each provider implements this to convert messages to their specific format. + * + * @param messages the list of messages + * @return Map of parameters for the provider's request body template + */ + public abstract Map mapMessages(List messages); + + /** + * Maps standardized AgentInput to provider-specific request body parameters. + * This is the main entry point that delegates to the specific mapping methods. 
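+     * <p>Illustrative usage (a sketch; the returned map feeds the provider's request body template):
+     * <pre>{@code
+     * ModelProvider provider = ModelProviderFactory.getProvider("bedrock/converse");
+     * Map<String, String> params = provider.mapAgentInput(new AgentInput("Hello")); // TEXT input delegates to mapTextInput
+     * }</pre>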
+ * + * @param agentInput the standardized agent input + * @return Map of parameters for the provider's request body template + * @throws IllegalArgumentException if input type is unsupported + */ + public Map mapAgentInput(AgentInput agentInput) { + if (agentInput == null || agentInput.getInput() == null) { + throw new IllegalArgumentException("AgentInput and its input field cannot be null"); + } + + InputType inputType = agentInput.getInputType(); + return switch (inputType) { + case TEXT -> mapTextInput((String) agentInput.getInput()); + case CONTENT_BLOCKS -> { + @SuppressWarnings("unchecked") + List blocks = (List) agentInput.getInput(); + yield mapContentBlocks(blocks); + } + case MESSAGES -> { + @SuppressWarnings("unchecked") + List messages = (List) agentInput.getInput(); + yield mapMessages(messages); + } + default -> throw new IllegalArgumentException("Unsupported input type: " + inputType); + }; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderFactory.java b/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderFactory.java new file mode 100644 index 0000000000..deda678854 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderFactory.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +/** + * Factory class for creating model providers + */ +// ToDo: modify this to a map that is automatically created using the enum +// Have the enum define the provider class +public class ModelProviderFactory { + + /** + * Get model provider instance based on provider type + * @param providerType the provider type string + * @return ModelProvider instance + * @throws IllegalArgumentException if provider type is not supported + */ + public static ModelProvider getProvider(String providerType) { + ModelProviderType type = ModelProviderType.from(providerType); + return switch (type) { + case BEDROCK_CONVERSE -> new BedrockConverseModelProvider(); + default -> throw new IllegalArgumentException("Unsupported model provider type: " + providerType); + }; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderType.java b/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderType.java new file mode 100644 index 0000000000..9baf1aea34 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/ModelProviderType.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +/** + * Enum for supported model provider types + */ +public enum ModelProviderType { + BEDROCK_CONVERSE("bedrock/converse"), + OPENAI("openai"), + AZURE_OPENAI("azure/openai"); + + private final String value; + + ModelProviderType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static ModelProviderType from(String value) { + for (ModelProviderType type : ModelProviderType.values()) { + if (type.value.equalsIgnoreCase(value)) { + return type; + } + } + throw new IllegalArgumentException("Unknown model provider type: " + value); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/SourceType.java b/common/src/main/java/org/opensearch/ml/common/agent/SourceType.java new file mode 100644 index 0000000000..d276ffc24a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/SourceType.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +/** + * Enum representing the different ways content data can be provided. + */ +public enum SourceType { + URL, + BASE64 +} diff --git a/common/src/main/java/org/opensearch/ml/common/agent/VideoContent.java b/common/src/main/java/org/opensearch/ml/common/agent/VideoContent.java new file mode 100644 index 0000000000..ec11481b60 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/agent/VideoContent.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.agent; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Represents video content with type, format, and data fields. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class VideoContent { + private SourceType type; + private String format; // "mp4", "mov", "avi", etc. + private String data; // URL or base64 data +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java index 4d617ce896..c411a45138 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java @@ -22,9 +22,11 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode +@Setter public class ConnectorClientConfig implements ToXContentObject, Writeable { public static final String MAX_CONNECTION_FIELD = "max_connection"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index 986d6eefef..48a24bd324 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -8,6 +8,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.CommonValue.VERSION_2_19_0; +import static org.opensearch.ml.common.CommonValue.VERSION_3_3_0; import java.io.IOException; import java.util.Map; @@ -18,6 +19,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.agent.AgentInput; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -31,6 +33,7 @@ public class AgentMLInput extends MLInput { public static final String AGENT_ID_FIELD = "agent_id"; public static final String PARAMETERS_FIELD = "parameters"; + public static final String INPUT_FIELD = "input"; public static final String ASYNC_FIELD = "isAsync"; public static final Version MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION = CommonValue.VERSION_3_0_0; @@ -47,6 +50,10 @@ public class AgentMLInput extends MLInput { @Setter private Boolean isAsync; + @Getter + @Setter + private AgentInput agentInput; + @Builder(builderMethodName = "AgentMLInputBuilder") public AgentMLInput(String agentId, String tenantId, FunctionName functionName, 
MLInputDataset inputDataset) { this(agentId, tenantId, functionName, inputDataset, false); @@ -59,6 +66,25 @@ public AgentMLInput(String agentId, String tenantId, FunctionName functionName, this.algorithm = functionName; this.inputDataset = inputDataset; this.isAsync = isAsync; + this.agentInput = null; // Legacy constructor - no standardized input + } + + // New constructor for standardized input + @Builder(builderMethodName = "AgentMLInputBuilderWithStandardInput") + public AgentMLInput( + String agentId, + String tenantId, + FunctionName functionName, + AgentInput agentInput, + MLInputDataset inputDataset, + Boolean isAsync + ) { + this.agentId = agentId; + this.tenantId = tenantId; + this.algorithm = functionName; + this.agentInput = agentInput; + this.inputDataset = inputDataset; + this.isAsync = isAsync; } @Override @@ -72,6 +98,13 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) { out.writeOptionalBoolean(isAsync); } + // Todo: finalize the version + if (streamOutputVersion.onOrAfter(VERSION_3_3_0)) { + out.writeBoolean(agentInput != null); + if (agentInput != null) { + agentInput.writeTo(out); + } + } } public AgentMLInput(StreamInput in) throws IOException { @@ -82,6 +115,11 @@ public AgentMLInput(StreamInput in) throws IOException { if (streamInputVersion.onOrAfter(AgentMLInput.MINIMAL_SUPPORTED_VERSION_FOR_ASYNC_EXECUTION)) { this.isAsync = in.readOptionalBoolean(); } + if (streamInputVersion.onOrAfter(VERSION_3_3_0)) { + if (in.readBoolean()) { + this.agentInput = new AgentInput(in); + } + } } public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOException { @@ -100,9 +138,13 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE tenantId = parser.textOrNull(); break; case PARAMETERS_FIELD: + // Legacy format - parse parameters into RemoteInferenceInputDataSet Map parameters = StringUtils.getParameterMap(parser.map()); inputDataset = new RemoteInferenceInputDataSet(parameters); break; + case INPUT_FIELD: + agentInput = new AgentInput(parser); + break; case ASYNC_FIELD: isAsync = parser.booleanValue(); break; @@ -113,4 +155,11 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE } } + /** + * Checks if this AgentMLInput uses the new standardized input format. 
+ * @return true if AgentInput is present + */ + public boolean hasStandardInput() { + return agentInput != null; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 11b3a5b976..3ddfa8b214 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -51,8 +51,12 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.agent.AgentInput; +import org.opensearch.ml.common.agent.AgentInputProcessor; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLMemorySpec; +import org.opensearch.ml.common.agent.ModelProvider; +import org.opensearch.ml.common.agent.ModelProviderFactory; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -144,16 +148,14 @@ public void onMultiTenancyEnabledChanged(boolean isEnabled) { @Override public void execute(Input input, ActionListener listener, TransportChannel channel) { - if (!(input instanceof AgentMLInput)) { + if (!(input instanceof AgentMLInput agentMLInput)) { throw new IllegalArgumentException("wrong input"); } - AgentMLInput agentMLInput = (AgentMLInput) input; String agentId = agentMLInput.getAgentId(); String tenantId = agentMLInput.getTenantId(); Boolean isAsync = agentMLInput.getIsAsync(); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset(); - if (inputDataSet == null || inputDataSet.getParameters() == null) { + if (agentMLInput.getInputDataset() == null && !agentMLInput.hasStandardInput()) { throw new IllegalArgumentException("Agent input data can not be empty."); } @@ -215,6 +217,11 @@ public void execute(Input input, ActionListener listener, TransportChann ) ); } + + processAgentInput(agentMLInput, mlAgent); + + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput + .getInputDataset(); MLMemorySpec memorySpec = mlAgent.getMemory(); String memoryId = inputDataSet.getParameters().get(MEMORY_ID); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); @@ -728,4 +735,43 @@ private void updateInteractionWithFailure(String interactionId, Memory memory, S ); } } + + /** + * Processes standardized input if present in AgentMLInput. + * This method handles the conversion from AgentInput to parameters that can be used + * by the existing agent execution logic. 
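+     * <p>Sketch of the conversion (illustrative values; the exact request body is produced by the provider templates):
+     * <pre>{@code
+     * // before: parameters = {"question": "What is OpenSearch?"}
+     * // after : parameters = {"question": "What is OpenSearch?", "body": "{...provider-specific request...}"}
+     * }</pre>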
+     */
+    private void processAgentInput(AgentMLInput agentMLInput, MLAgent mlAgent) {
+        // If legacy question input is provided, parse to new standard input
+        if (agentMLInput.getInputDataset() != null) {
+            RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
+            if (!remoteInferenceInputDataSet.getParameters().containsKey(QUESTION)) {
+                throw new IllegalArgumentException("Question not found in parameters.");
+            }
+
+            AgentInput standardInput = new AgentInput(remoteInferenceInputDataSet.getParameters().get(QUESTION));
+            agentMLInput.setAgentInput(standardInput);
+        }
+
+        try {
+            // Extract the question text for prompt template and memory storage
+            String question = AgentInputProcessor.extractQuestionText(agentMLInput.getAgentInput());
+            ModelProvider modelProvider = ModelProviderFactory.getProvider(mlAgent.getModel().getModelProvider());
+
+            // Create the input dataset if it doesn't exist
+            if (agentMLInput.getInputDataset() == null) {
+                agentMLInput.setInputDataset(new RemoteInferenceInputDataSet(new HashMap<>()));
+            }
+
+            // Merge the provider-specific parameters into the input dataset
+            RemoteInferenceInputDataSet remoteDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset();
+            Map parameters = modelProvider.mapAgentInput(agentMLInput.getAgentInput());
+            // Keep the plain question text under QUESTION for memory storage
+            parameters.put(QUESTION, question);
+            remoteDataSet.getParameters().putAll(parameters);
+        } catch (Exception e) {
+            log.error("Failed to process standardized input for agent {}", mlAgent.getName(), e);
+            throw new IllegalArgumentException("Failed to process standardized agent input: " + e.getMessage(), e);
+        }
+    }
 }
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
index a8b7b54d71..90e2e4dfb8 100644
--- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java
@@ -123,6 +123,7 @@ public class MLChatAgentRunner implements MLAgentRunner {
     public static final String INJECT_DATETIME_FIELD = "inject_datetime";
     public static final String DATETIME_FORMAT_FIELD = "datetime_format";
     public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
+    private static final String DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."; // default system prompt used when none is provided
     private static final String DEFAULT_MAX_ITERATIONS = "10";
     private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
@@ -829,6 +830,19 @@ static Map constructLLMParams(LLMSpec llm, Map p
         tmpParameters.putIfAbsent(PROMPT_SUFFIX, PromptTemplate.PROMPT_TEMPLATE_SUFFIX);
         tmpParameters.putIfAbsent(RESPONSE_FORMAT_INSTRUCTION, PromptTemplate.PROMPT_FORMAT_INSTRUCTION);
         tmpParameters.putIfAbsent(TOOL_RESPONSE, PromptTemplate.PROMPT_TEMPLATE_TOOL_RESPONSE);
+
+        // Set default system prompt only if none exists
+        if (!tmpParameters.containsKey(SYSTEM_PROMPT_FIELD)) {
+            String systemPrompt = DEFAULT_SYSTEM_PROMPT;
+            // If datetime injection was enabled, include it in the default system prompt
+            if (injectDate) {
+                String dateFormat = tmpParameters.get(DATETIME_FORMAT_FIELD);
+                String currentDateTime = getCurrentDateTime(dateFormat);
+                systemPrompt = systemPrompt + "\n\n" + currentDateTime;
+            }
+            tmpParameters.put(SYSTEM_PROMPT_FIELD, systemPrompt);
+        }
+
         return
tmpParameters; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index e91d27c9bb..f4a4b8ff0e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -27,12 +27,18 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.MLAgentType; +import org.opensearch.ml.common.agent.AgentModelService; +import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLAgentModelSpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction; import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest; import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse; +import org.opensearch.ml.common.transport.register.MLRegisterModelAction; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelRequest; import org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -80,9 +86,49 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + try { + MLRegisterModelInput modelInput = AgentModelService.createModelFromSpec(mlAgent.getModel()); + MLRegisterModelRequest modelRequest = new MLRegisterModelRequest(modelInput); + + client.execute(MLRegisterModelAction.INSTANCE, modelRequest, ActionListener.wrap(modelResponse -> { + String modelId = modelResponse.getModelId(); + + Map parameters = new HashMap<>(); + if (mlAgent.getParameters() != null) { + parameters.putAll(mlAgent.getParameters()); + } + + String llmInterface = AgentModelService.inferLLMInterface(mlAgent.getModel().getModelProvider()); + if (llmInterface != null) { + parameters.put(LLM_INTERFACE, llmInterface); + } + + LLMSpec llmSpec = LLMSpec.builder().modelId(modelId).parameters(mlAgent.getModel().getModelParameters()).build(); + + // Remove credentials and model parameters as it is stored in the model document and LLMSpec respectively + MLAgentModelSpec modelSpec = mlAgent.getModel(); + modelSpec.setModelParameters(null); + modelSpec.setCredential(null); + // ToDo: store model details within agent to prevent creating a new model document + MLAgent agent = mlAgent.toBuilder().llm(llmSpec).model(modelSpec).parameters(parameters).build(); + registerAgent(agent, listener); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + private void registerAgent(MLAgent agent, ActionListener listener) { String mcpConnectorConfigJSON = (agent.getParameters() != null) ? 
agent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { From 4aef03d7ed45a008d92c26def43ddec3c1a23acb Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 29 Oct 2025 14:27:59 -0700 Subject: [PATCH 14/58] create internal connector in memory container Signed-off-by: Yaliang Wu --- .../common/memorycontainer/RemoteStore.java | 22 + .../index-mappings/ml_memory_container.json | 28 + .../TransportCreateMemoryContainerAction.java | 162 ++++- .../TransportGetMemoryContainerAction.java | 7 + .../TransportSearchMemoryContainerAction.java | 29 + .../TransportUpdateMemoryContainerAction.java | 7 +- .../ml/helper/MemoryContainerHelper.java | 106 +-- .../helper/MemoryContainerPipelineHelper.java | 122 ++-- ...lper.java => RemoteMemoryStoreHelper.java} | 630 +++++++++++++++--- .../ml/plugin/MachineLearningPlugin.java | 18 +- 10 files changed, 885 insertions(+), 246 deletions(-) rename plugin/src/main/java/org/opensearch/ml/helper/{RemoteStorageHelper.java => RemoteMemoryStoreHelper.java} (60%) diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java index 41d1262ace..c8f68f157b 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/RemoteStore.java @@ -20,6 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.Connector; import lombok.Builder; import lombok.Data; @@ -33,6 +34,7 @@ public class RemoteStore implements ToXContentObject, Writeable { public static final String TYPE_FIELD = "type"; + public static final String CONNECTOR_FIELD = "connector"; public static final String CONNECTOR_ID_FIELD = "connector_id"; public static final String ENDPOINT_FIELD = "endpoint"; public static final String PARAMETERS_FIELD = "parameters"; @@ -42,6 +44,7 @@ public class RemoteStore implements ToXContentObject, Writeable { public static final String SEARCH_PIPELINE_FIELD = "search_pipeline"; private RemoteStoreType type; + private Connector connector; private String connectorId; private FunctionName embeddingModelType; private String embeddingModelId; @@ -62,6 +65,7 @@ public class RemoteStore implements ToXContentObject, Writeable { @Builder public RemoteStore( RemoteStoreType type, + Connector connector, String connectorId, FunctionName embeddingModelType, String embeddingModelId, @@ -77,6 +81,7 @@ public RemoteStore( throw new IllegalArgumentException("Invalid remote store type"); } this.type = type; + this.connector = connector; this.connectorId = connectorId; this.embeddingModelType = embeddingModelType; this.embeddingModelId = embeddingModelId; @@ -91,6 +96,9 @@ public RemoteStore( public RemoteStore(StreamInput input) throws IOException { this.type = input.readEnum(RemoteStoreType.class); + if (input.readBoolean()) { + this.connector = Connector.fromStream(input); + } this.connectorId = input.readOptionalString(); if (input.readOptionalBoolean()) { this.embeddingModelType = input.readEnum(FunctionName.class); @@ -118,6 +126,12 @@ public RemoteStore(StreamInput input) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { out.writeEnum(type); + if (connector != null) { + out.writeBoolean(true); + connector.writeTo(out); + } 
else { + out.writeBoolean(false); + } out.writeOptionalString(connectorId); if (embeddingModelType != null) { out.writeBoolean(true); @@ -156,6 +170,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (type != null) { builder.field(TYPE_FIELD, type); } + if (connector != null) { + builder.field(CONNECTOR_FIELD, connector); + } if (connectorId != null) { builder.field(CONNECTOR_ID_FIELD, connectorId); } @@ -190,6 +207,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static RemoteStore parse(XContentParser parser) throws IOException { RemoteStoreType type = null; + Connector connector = null; String connectorId = null; FunctionName embeddingModelType = null; String embeddingModelId = null; @@ -210,6 +228,9 @@ public static RemoteStore parse(XContentParser parser) throws IOException { case TYPE_FIELD: type = RemoteStoreType.fromString(parser.text()); break; + case CONNECTOR_FIELD: + connector = Connector.createConnector(parser); + break; case CONNECTOR_ID_FIELD: connectorId = parser.text(); break; @@ -249,6 +270,7 @@ public static RemoteStore parse(XContentParser parser) throws IOException { return RemoteStore .builder() .type(type) + .connector(connector) .connectorId(connectorId) .embeddingModelType(embeddingModelType) .embeddingModelId(embeddingModelId) diff --git a/common/src/main/resources/index-mappings/ml_memory_container.json b/common/src/main/resources/index-mappings/ml_memory_container.json index a39f7e0e24..14f8843111 100644 --- a/common/src/main/resources/index-mappings/ml_memory_container.json +++ b/common/src/main/resources/index-mappings/ml_memory_container.json @@ -58,8 +58,36 @@ "type": { "type": "keyword" }, + "endpoint": { + "type": "keyword" + }, + "parameters": { + "type": "flat_object" + }, + "credential": { + "type": "flat_object" + }, + "connector": CONNECTOR_MAPPING_PLACEHOLDER, "connector_id": { "type": "keyword" + }, + "embedding_model_id": { + "type": "keyword" + }, + "embedding_model_type": { + "type": "keyword" + }, + "embedding_dimension": { + "type": "integer" + }, + "embedding_model": { + "type": "flat_object" + }, + "ingest_pipeline": { + "type": "keyword" + }, + "search_pipeline": { + "type": "keyword" } } }, diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index fed445e9d0..3b04623a12 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -7,16 +7,17 @@ import static org.opensearch.ml.common.CommonValue.ML_MEMORY_CONTAINER_INDEX; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE; +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.determineProtocol; -import static org.opensearch.ml.helper.RemoteStorageHelper.BULK_LOAD_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INDEX_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.CREATE_INGEST_PIPELINE_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.DELETE_DOC_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.GET_DOC_ACTION; 
-import static org.opensearch.ml.helper.RemoteStorageHelper.REGISTER_MODEL_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.SEARCH_INDEX_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.UPDATE_DOC_ACTION; -import static org.opensearch.ml.helper.RemoteStorageHelper.WRITE_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.BULK_LOAD_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.CREATE_INDEX_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.CREATE_INGEST_PIPELINE_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.DELETE_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.GET_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.REGISTER_MODEL_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.SEARCH_INDEX_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.UPDATE_DOC_ACTION; +import static org.opensearch.ml.helper.RemoteMemoryStoreHelper.WRITE_DOC_ACTION; import java.time.Instant; import java.util.ArrayList; @@ -31,11 +32,17 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.memorycontainer.MLMemoryContainer; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; @@ -43,16 +50,18 @@ import org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel; import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerAction; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerInput; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerRequest; import org.opensearch.ml.common.transport.memorycontainer.MLCreateMemoryContainerResponse; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.helper.MemoryContainerModelValidator; import org.opensearch.ml.helper.MemoryContainerPipelineHelper; import org.opensearch.ml.helper.MemoryContainerSharedIndexValidator; -import org.opensearch.ml.helper.RemoteStorageHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.ml.utils.TenantAwareHelper; @@ -74,29 +83,49 @@ public class TransportCreateMemoryContainerAction extends private final MLIndicesHandler mlIndicesHandler; private final Client client; + private final ClusterService clusterService; + 
private final Settings settings; private final SdkClient sdkClient; private final ConnectorAccessControlHelper connectorAccessControlHelper; private final MLFeatureEnabledSetting mlFeatureEnabledSetting; private final MLModelManager mlModelManager; + private final MLEngine mlEngine; + private final RemoteMemoryStoreHelper remoteMemoryStoreHelper; + private final MemoryContainerPipelineHelper memoryContainerPipelineHelper; + private volatile List trustedConnectorEndpointsRegex; @Inject public TransportCreateMemoryContainerAction( TransportService transportService, ActionFilters actionFilters, Client client, + ClusterService clusterService, + Settings settings, SdkClient sdkClient, MLIndicesHandler mlIndicesHandler, ConnectorAccessControlHelper connectorAccessControlHelper, MLFeatureEnabledSetting mlFeatureEnabledSetting, - MLModelManager mlModelManager + MLModelManager mlModelManager, + MLEngine mlEngine, + RemoteMemoryStoreHelper remoteMemoryStoreHelper, + MemoryContainerPipelineHelper memoryContainerPipelineHelper ) { super(MLCreateMemoryContainerAction.NAME, transportService, actionFilters, MLCreateMemoryContainerRequest::new); this.client = client; this.sdkClient = sdkClient; this.mlIndicesHandler = mlIndicesHandler; + this.clusterService = clusterService; + this.settings = settings; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; this.mlModelManager = mlModelManager; + this.mlEngine = mlEngine; + this.remoteMemoryStoreHelper = remoteMemoryStoreHelper; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + this.memoryContainerPipelineHelper = memoryContainerPipelineHelper; + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); } @Override @@ -296,7 +325,7 @@ private void createMemoryIndexes( } private void createLongTermMemoryIngestPipeline(String indexName, MemoryConfiguration memoryConfig, ActionListener listener) { - MemoryContainerPipelineHelper.createLongTermMemoryIngestPipeline(indexName, memoryConfig, mlIndicesHandler, client, listener); + memoryContainerPipelineHelper.createLongTermMemoryIngestPipeline(indexName, memoryConfig, listener); } private void indexMemoryContainer(MLMemoryContainer container, ActionListener listener) { @@ -347,26 +376,57 @@ private void indexMemoryContainer(MLMemoryContainer container, ActionListener listener) { + if (config.getRemoteStore() != null && config.getRemoteStore().getConnector() != null) { + if (config.getRemoteStore().getEmbeddingModel() != null + && (config.getRemoteStore().getIngestPipeline() == null || config.getRemoteStore().getIngestPipeline().isEmpty())) { + remoteMemoryStoreHelper + .createRemoteEmbeddingModel( + config.getRemoteStore().getConnector(), + config.getRemoteStore().getEmbeddingModel(), + config.getRemoteStore().getCredential(), + ActionListener.wrap(modelId -> { + // Set the embedding model ID in the remote store config + config.getRemoteStore().setEmbeddingModelId(modelId); + // Also set type and dimension from embedding model config + RemoteEmbeddingModel embModel = config.getRemoteStore().getEmbeddingModel(); + config.getRemoteStore().setEmbeddingModelType(embModel.getModelType()); + config.getRemoteStore().setEmbeddingDimension(embModel.getDimension()); + log.info("Auto-created embedding model with ID: {} in remote store", modelId); + // Continue with normal validation + 
validateConfigurationInternal(config, listener); + }, listener::onFailure) + ); + } else { + if (config.getRemoteStore().getIngestPipeline() != null && !config.getRemoteStore().getIngestPipeline().isEmpty()) { + log + .info( + "Skipping embedding model auto-creation, using pre-existing ingest pipeline: {}", + config.getRemoteStore().getIngestPipeline() + ); + } + // Continue with normal validation + validateConfigurationInternal(config, listener); + } + return; + } // Check if we need to auto-create a connector if (config.getRemoteStore() != null && config.getRemoteStore().getConnectorId() == null && config.getRemoteStore().getEndpoint() != null) { // Auto-create connector first - createConnectorForRemoteStore(config.getRemoteStore(), ActionListener.wrap(connectorId -> { + createInternalConnectorForRemoteStore(config.getRemoteStore(), ActionListener.wrap(connector -> { // Set the connector ID in the remote store config - config.getRemoteStore().setConnectorId(connectorId); - log.info("Auto-created connector with ID: {} for remote store", connectorId); + config.getRemoteStore().setConnector(connector); // Check if we need to auto-create embedding model // Skip if user provided a pre-existing ingest pipeline if (config.getRemoteStore().getEmbeddingModel() != null && (config.getRemoteStore().getIngestPipeline() == null || config.getRemoteStore().getIngestPipeline().isEmpty())) { - RemoteStorageHelper + remoteMemoryStoreHelper .createRemoteEmbeddingModel( - connectorId, + connector, config.getRemoteStore().getEmbeddingModel(), config.getRemoteStore().getCredential(), - client, ActionListener.wrap(modelId -> { // Set the embedding model ID in the remote store config config.getRemoteStore().setEmbeddingModelId(modelId); @@ -447,13 +507,13 @@ private void validateConfigurationInternal(MemoryConfiguration config, ActionLis } private void createRemoteSessionMemoryIndex(MemoryConfiguration configuration, String indexName, ActionListener listener) { - String connectorId = configuration.getRemoteStore().getConnectorId(); - RemoteStorageHelper.createRemoteSessionMemoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + Connector connector = configuration.getRemoteStore().getConnector(); + remoteMemoryStoreHelper.createRemoteSessionMemoryIndex(connector, indexName, configuration, listener); } private void createRemoteWorkingMemoryIndex(MemoryConfiguration configuration, String indexName, ActionListener listener) { - String connectorId = configuration.getRemoteStore().getConnectorId(); - RemoteStorageHelper.createRemoteWorkingMemoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + Connector connector = configuration.getRemoteStore().getConnector(); + remoteMemoryStoreHelper.createRemoteWorkingMemoryIndex(connector, indexName, configuration, listener); } private void createRemoteLongTermMemoryHistoryIndex( @@ -461,9 +521,8 @@ private void createRemoteLongTermMemoryHistoryIndex( String indexName, ActionListener listener ) { - String connectorId = configuration.getRemoteStore().getConnectorId(); - RemoteStorageHelper - .createRemoteLongTermMemoryHistoryIndex(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + Connector connector = configuration.getRemoteStore().getConnector(); + remoteMemoryStoreHelper.createRemoteLongTermMemoryHistoryIndex(connector, indexName, configuration, listener); } private void createRemoteMemoryIndexes( @@ -493,9 +552,7 @@ private void createRemoteLongTermMemoryIngestPipeline( String 
indexName, ActionListener listener ) { - String connectorId = configuration.getRemoteStore().getConnectorId(); - MemoryContainerPipelineHelper - .createRemoteLongTermMemoryIngestPipeline(connectorId, indexName, configuration, mlIndicesHandler, client, listener); + memoryContainerPipelineHelper.createRemoteLongTermMemoryIngestPipeline(indexName, configuration, listener); } /** @@ -542,6 +599,14 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen request, ActionListener.wrap(response -> { log.info("Successfully created connector: {}", response.getConnectorId()); + // // Store the connector object in remote store configuration + // org.opensearch.ml.common.connector.Connector connector = connectorInput.toConnector(); + // // Encrypt connector credentials before storing (similar to indexRemoteModel in MLModelManager) + // String tenantId = remoteStore.getParameters() != null + // ? remoteStore.getParameters().get("tenant_id") + // : null; + // connector.encrypt(mlEngine::encrypt, tenantId); + // remoteStore.setConnector(connector); listener.onResponse(response.getConnectorId()); }, e -> { log.error("Failed to create connector for remote store", e); @@ -554,6 +619,47 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen } } + private void createInternalConnectorForRemoteStore(RemoteStore remoteStore, ActionListener listener) { + try { + String connectorName = "auto_" + + remoteStore.getType().name().toLowerCase() + + "_connector_" + + UUID.randomUUID().toString().substring(0, 8); + + // Build connector actions based on remote store type + List actions = buildConnectorActions(remoteStore); + + // Get credential and parameters from remote store + Map credential = remoteStore.getCredential(); + Map parameters = remoteStore.getParameters(); + + // Determine protocol based on parameters or credential + String protocol = determineProtocol(parameters, credential); + + // Create connector input + MLCreateConnectorInput connectorInput = MLCreateConnectorInput + .builder() + .name(connectorName) + .description("Auto-generated connector for " + remoteStore.getType() + " remote memory store") + .version("1") + .protocol(protocol) + .parameters(parameters) + .credential(credential) + .actions(actions) + .build(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + connectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + Connector connector = Connector.createConnector(builder, connectorInput.getProtocol()); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + connector.encrypt(mlEngine::encrypt, null); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Error building connector for remote store", e); + listener.onFailure(e); + } + } + /** * Builds connector actions based on remote store type */ diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportGetMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportGetMemoryContainerAction.java index a87ac9bbd0..7e367ff750 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportGetMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportGetMemoryContainerAction.java @@ -169,6 +169,13 @@ private void processResponse( ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLMemoryContainer mlMemoryContainer = MLMemoryContainer.parse(parser); + if (mlMemoryContainer != null + && 
mlMemoryContainer.getConfiguration() != null + && mlMemoryContainer.getConfiguration().getRemoteStore() != null + && mlMemoryContainer.getConfiguration().getRemoteStore().getConnector() != null) { + mlMemoryContainer.getConfiguration().getRemoteStore().getConnector().removeCredential(); + } + if (TenantAwareHelper .validateTenantResource(mlFeatureEnabledSetting, tenantId, mlMemoryContainer.getTenantId(), wrappedListener)) { validateMemoryContainerAccess(user, memoryContainerId, mlMemoryContainer, wrappedListener); diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportSearchMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportSearchMemoryContainerAction.java index 7d870ce4f1..23417e643e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportSearchMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportSearchMemoryContainerAction.java @@ -9,6 +9,12 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -19,6 +25,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.memorycontainer.MLMemoryContainerSearchAction; import org.opensearch.ml.common.transport.search.MLSearchActionRequest; @@ -28,6 +35,8 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -92,6 +101,26 @@ private void preProcessRoleAndPerformSearch( memoryContainerHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by {}", user.getBackendRoles()); } + + // Exclude credential fields from connector in remote_store configuration + List excludes = Optional + .ofNullable(request.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::excludes) + .map(x -> Arrays.stream(x).collect(Collectors.toList())) + .orElse(new ArrayList<>()); + excludes.add("configuration.remote_store.connector." 
+ HttpConnector.CREDENTIAL_FIELD); + FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( + Optional + .ofNullable(request.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::fetchSource) + .orElse(true), + Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null), + excludes.toArray(new String[0]) + ); + request.source().fetchSource(rebuiltFetchSourceContext); + SearchDataObjectRequest searchDataObjecRequest = SearchDataObjectRequest .builder() .indices(request.indices()) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java index 72ad3e32de..db6d55a412 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java @@ -66,6 +66,7 @@ public class TransportUpdateMemoryContainerAction extends HandledTransportAction final MLModelManager mlModelManager; final MemoryContainerHelper memoryContainerHelper; final MLIndicesHandler mlIndicesHandler; + final MemoryContainerPipelineHelper memoryContainerPipelineHelper; @Inject public TransportUpdateMemoryContainerAction( @@ -78,7 +79,8 @@ public TransportUpdateMemoryContainerAction( MLFeatureEnabledSetting mlFeatureEnabledSetting, MLModelManager mlModelManager, MemoryContainerHelper memoryContainerHelper, - MLIndicesHandler mlIndicesHandler + MLIndicesHandler mlIndicesHandler, + MemoryContainerPipelineHelper memoryContainerPipelineHelper ) { super(MLUpdateMemoryContainerAction.NAME, transportService, actionFilters, MLUpdateMemoryContainerRequest::new); this.client = client; @@ -89,6 +91,7 @@ public TransportUpdateMemoryContainerAction( this.mlModelManager = mlModelManager; this.memoryContainerHelper = memoryContainerHelper; this.mlIndicesHandler = mlIndicesHandler; + this.memoryContainerPipelineHelper = memoryContainerPipelineHelper; } @Override @@ -394,6 +397,6 @@ private void createLongTermAndHistoryIndices( * Creates ingest pipeline and long-term index. 
*/ private void createLongTermMemoryIngestPipeline(String indexName, MemoryConfiguration config, ActionListener listener) { - MemoryContainerPipelineHelper.createLongTermMemoryIngestPipeline(indexName, config, mlIndicesHandler, client, listener); + memoryContainerPipelineHelper.createLongTermMemoryIngestPipeline(indexName, config, listener); } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index 47fa2d3b12..7047f2a37d 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -67,6 +67,7 @@ import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.MemoryStrategy; import org.opensearch.ml.common.memorycontainer.MemoryType; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; @@ -90,12 +91,30 @@ public class MemoryContainerHelper { Client client; SdkClient sdkClient; NamedXContentRegistry xContentRegistry; + RemoteMemoryStoreHelper remoteMemoryStoreHelper; @Inject - public MemoryContainerHelper(Client client, SdkClient sdkClient, NamedXContentRegistry xContentRegistry) { + public MemoryContainerHelper( + Client client, + SdkClient sdkClient, + NamedXContentRegistry xContentRegistry, + RemoteMemoryStoreHelper remoteMemoryStoreHelper + ) { this.client = client; this.sdkClient = sdkClient; this.xContentRegistry = xContentRegistry; + this.remoteMemoryStoreHelper = remoteMemoryStoreHelper; + } + + /** + * Check if remote store is configured with either connectorId or internal connector + * + * @param configuration the memory configuration + * @return true if remote store is configured + */ + private boolean hasRemoteStore(MemoryConfiguration configuration) { + return configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null); } /** @@ -234,7 +253,8 @@ public String getMemoryIndexName(MLMemoryContainer container, MemoryType memoryT } public void getData(MemoryConfiguration configuration, GetRequest getRequest, ActionListener listener) { - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null)) { getDataFromRemoteStorage(configuration, getRequest, listener); } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -247,17 +267,16 @@ public void getData(MemoryConfiguration configuration, GetRequest getRequest, Ac private void getDataFromRemoteStorage(MemoryConfiguration configuration, GetRequest getRequest, ActionListener listener) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); String indexName = getRequest.indices()[0]; String docId = getRequest.id(); // Convert SearchSourceBuilder to Map - RemoteStorageHelper + remoteMemoryStoreHelper .getDocument( - connectorId, + remoteStore, indexName, docId, - client, ActionListener.wrap(response -> { 
listener.onResponse(response); }, listener::onFailure) ); } catch (Exception e) { @@ -272,8 +291,9 @@ public void searchData( ActionListener listener ) { try { - // Check if remote store is configured - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + // Check if remote store is configured (either with connectorId or internal connector) + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null)) { // Use remote storage // searchDataFromRemoteStorage(configuration, searchRequest, listener); throw new RuntimeException("Remote store is not yet implemented"); @@ -304,9 +324,9 @@ public void searchDataFromRemoteStorage( ActionListener listener ) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); - String searchPipeline = configuration.getRemoteStore().getSearchPipeline(); - RemoteStorageHelper.searchDocuments(connectorId, indexName, query, searchPipeline, client, ActionListener.wrap(response -> { + RemoteStore remoteStore = configuration.getRemoteStore(); + String searchPipeline = remoteStore.getSearchPipeline(); + remoteMemoryStoreHelper.searchDocuments(remoteStore, indexName, query, searchPipeline, ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure)); } catch (Exception e) { @@ -328,8 +348,9 @@ private Map convertSearchSourceToMap(SearchSourceBuilder searchS } public void indexData(MemoryConfiguration configuration, IndexRequest indexRequest, ActionListener listener) { - // Check if remote store is configured - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + // Check if remote store is configured (either with connectorId or internal connector) + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null)) { // Use remote storage indexDataToRemoteStorage(configuration, indexRequest, listener); } else if (configuration.isUseSystemIndex()) { @@ -347,25 +368,24 @@ public void updateDataToRemoteStorage( ActionListener listener ) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); String indexName = indexRequest.index(); String docId = indexRequest.id(); // Convert IndexRequest source to Map Map documentSource = indexRequest.sourceAsMap(); - RemoteStorageHelper - .updateDocument(connectorId, indexName, docId, documentSource, client, ActionListener.wrap(updateResponse -> { - IndexResponse response = new IndexResponse( - updateResponse.getShardId(), - updateResponse.getId(), - updateResponse.getSeqNo(), - updateResponse.getPrimaryTerm(), - updateResponse.getVersion(), - false - ); - listener.onResponse(response); - }, listener::onFailure)); + remoteMemoryStoreHelper.updateDocument(remoteStore, indexName, docId, documentSource, ActionListener.wrap(updateResponse -> { + IndexResponse response = new IndexResponse( + updateResponse.getShardId(), + updateResponse.getId(), + updateResponse.getSeqNo(), + updateResponse.getPrimaryTerm(), + updateResponse.getVersion(), + false + ); + listener.onResponse(response); + }, listener::onFailure)); } catch (Exception e) { log.error("Failed to index data to remote storage", e); listener.onFailure(e); @@ -378,12 +398,12 @@ private void indexDataToRemoteStorage( ActionListener listener ) { try { - 
String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); String indexName = indexRequest.index(); // Convert IndexRequest source to Map Map documentSource = indexRequest.sourceAsMap(); - RemoteStorageHelper.writeDocument(connectorId, indexName, documentSource, client, listener); + remoteMemoryStoreHelper.writeDocument(remoteStore, indexName, documentSource, listener); } catch (Exception e) { log.error("Failed to index data to remote storage", e); listener.onFailure(e); @@ -391,7 +411,8 @@ private void indexDataToRemoteStorage( } public void updateData(MemoryConfiguration configuration, UpdateRequest updateRequest, ActionListener listener) { - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null)) { updateDataInRemoteStorage(configuration, updateRequest, listener); } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -408,12 +429,12 @@ private void updateDataInRemoteStorage( ActionListener listener ) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); String indexName = updateRequest.index(); String docId = updateRequest.id(); Map documentSource = convertUpdateRequestToMap(updateRequest); - RemoteStorageHelper.updateDocument(connectorId, indexName, docId, documentSource, client, ActionListener.wrap(response -> { + remoteMemoryStoreHelper.updateDocument(remoteStore, indexName, docId, documentSource, ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure)); } catch (Exception e) { @@ -439,7 +460,8 @@ private Map convertUpdateRequestToMap(UpdateRequest updateReques } public void deleteData(MemoryConfiguration configuration, DeleteRequest deleteRequest, ActionListener listener) { - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || configuration.getRemoteStore().getConnector() != null)) { deleteDataFromRemoteStorage(configuration, deleteRequest, listener); } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -456,16 +478,15 @@ private void deleteDataFromRemoteStorage( ActionListener listener ) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); String indexName = deleteRequest.index(); String docId = deleteRequest.id(); - RemoteStorageHelper + remoteMemoryStoreHelper .deleteDocument( - connectorId, + remoteStore, indexName, docId, - client, ActionListener.wrap(response -> { listener.onResponse(response); }, listener::onFailure) ); } catch (Exception e) { @@ -489,7 +510,8 @@ public void deleteIndex( } public void bulkIngestData(MemoryConfiguration configuration, BulkRequest bulkRequest, ActionListener listener) { - if (configuration.getRemoteStore() != null && configuration.getRemoteStore().getConnectorId() != null) { + if (configuration.getRemoteStore() != null + && (configuration.getRemoteStore().getConnectorId() != null || 
configuration.getRemoteStore().getConnector() != null)) { bulkIngestDataToRemoteStorage(configuration, bulkRequest, listener); } else if (configuration.isUseSystemIndex()) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -506,7 +528,7 @@ private void bulkIngestDataToRemoteStorage( ActionListener listener ) { try { - String connectorId = configuration.getRemoteStore().getConnectorId(); + RemoteStore remoteStore = configuration.getRemoteStore(); List bulkBodyList = convertBulkRequestToNDJSON(bulkRequest); if (bulkBodyList.isEmpty()) { @@ -515,7 +537,7 @@ private void bulkIngestDataToRemoteStorage( } // Process sequentially - bulkIngestSequentially(connectorId, bulkBodyList, 0, new ArrayList<>(), listener); + bulkIngestSequentially(remoteStore, bulkBodyList, 0, new ArrayList<>(), listener); } catch (Exception e) { log.error("Failed to bulk ingest data to remote storage", e); @@ -524,7 +546,7 @@ private void bulkIngestDataToRemoteStorage( } private void bulkIngestSequentially( - String connectorId, + RemoteStore remoteStore, List bulkBodyList, int index, List responses, @@ -537,10 +559,10 @@ private void bulkIngestSequentially( return; } - RemoteStorageHelper.bulkWrite(connectorId, bulkBodyList.get(index), client, ActionListener.wrap(response -> { + remoteMemoryStoreHelper.bulkWrite(remoteStore, bulkBodyList.get(index), ActionListener.wrap(response -> { responses.add(response); // Process next - bulkIngestSequentially(connectorId, bulkBodyList, index + 1, responses, finalListener); + bulkIngestSequentially(remoteStore, bulkBodyList, index + 1, responses, finalListener); }, finalListener::onFailure)); } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java index 44e7959f71..25f838864c 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerPipelineHelper.java @@ -18,13 +18,12 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.memorycontainer.MemoryConfiguration; import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.transport.client.Client; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; import lombok.extern.log4j.Log4j2; /** @@ -32,8 +31,21 @@ * Provides reusable pipeline creation logic for long-term memory indices. */ @Log4j2 -@NoArgsConstructor(access = AccessLevel.PRIVATE) -public final class MemoryContainerPipelineHelper { +public class MemoryContainerPipelineHelper { + + private final Client client; + private final MLIndicesHandler mlIndicesHandler; + private final RemoteMemoryStoreHelper remoteMemoryStoreHelper; + + public MemoryContainerPipelineHelper( + Client client, + MLIndicesHandler mlIndicesHandler, + RemoteMemoryStoreHelper remoteMemoryStoreHelper + ) { + this.client = client; + this.mlIndicesHandler = mlIndicesHandler; + this.remoteMemoryStoreHelper = remoteMemoryStoreHelper; + } /** * Creates an ingest pipeline and long-term memory index. 
@@ -45,30 +57,22 @@ public final class MemoryContainerPipelineHelper { * * @param indexName The long-term memory index name * @param config The memory configuration - * @param indicesHandler The ML indices handler - * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void createLongTermMemoryIngestPipeline( - String indexName, - MemoryConfiguration config, - MLIndicesHandler indicesHandler, - Client client, - ActionListener listener - ) { + public void createLongTermMemoryIngestPipeline(String indexName, MemoryConfiguration config, ActionListener listener) { try { // Check if user provided a pre-existing ingest pipeline at configuration level if (config.getIngestPipeline() != null && !config.getIngestPipeline().isEmpty()) { log.info("Using pre-existing ingest pipeline from configuration: {}", config.getIngestPipeline()); // Use the user-provided pipeline directly - indicesHandler.createLongTermMemoryIndex(config.getIngestPipeline(), indexName, config, listener); + mlIndicesHandler.createLongTermMemoryIndex(config.getIngestPipeline(), indexName, config, listener); } else if (config.getEmbeddingModelType() != null) { // Auto-create pipeline if embedding model is configured String pipelineName = indexName + "-embedding"; - createTextEmbeddingPipeline(pipelineName, config, client, ActionListener.wrap(success -> { + createTextEmbeddingPipeline(pipelineName, config, ActionListener.wrap(success -> { log.info("Successfully created text embedding pipeline: {}", pipelineName); - indicesHandler.createLongTermMemoryIndex(pipelineName, indexName, config, listener); + mlIndicesHandler.createLongTermMemoryIndex(pipelineName, indexName, config, listener); }, e -> { log.error("Failed to create text embedding pipeline '{}'", pipelineName, e); listener @@ -80,7 +84,7 @@ public static void createLongTermMemoryIngestPipeline( ); })); } else { - indicesHandler.createLongTermMemoryIndex(null, indexName, config, listener); + mlIndicesHandler.createLongTermMemoryIndex(null, indexName, config, listener); } } catch (Exception e) { log.error("Failed to create text embedding pipeline for long term memory index: {}", indexName, e); @@ -102,15 +106,9 @@ public static void createLongTermMemoryIngestPipeline( * * @param pipelineName The pipeline name * @param config The memory configuration - * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void createTextEmbeddingPipeline( - String pipelineName, - MemoryConfiguration config, - Client client, - ActionListener listener - ) { + public void createTextEmbeddingPipeline(String pipelineName, MemoryConfiguration config, ActionListener listener) { // Check if pipeline already exists (shared index scenario) client.admin().cluster().getPipeline(new GetPipelineRequest(pipelineName), ActionListener.wrap(response -> { if (!response.pipelines().isEmpty()) { @@ -122,7 +120,7 @@ public static void createTextEmbeddingPipeline( // Pipeline doesn't exist - create it try { - createPipelineInternal(pipelineName, config, client, listener); + createPipelineInternal(pipelineName, config, listener); } catch (IOException e) { log.error("Failed to build pipeline configuration for '{}'", pipelineName, e); listener @@ -136,7 +134,7 @@ public static void createTextEmbeddingPipeline( }, error -> { // Pipeline doesn't exist (404 error expected) - create it try { - createPipelineInternal(pipelineName, config, client, 
listener); + createPipelineInternal(pipelineName, config, listener); } catch (IOException e) { log.error("Failed to build pipeline configuration for '{}'", pipelineName, e); listener @@ -155,16 +153,11 @@ public static void createTextEmbeddingPipeline( * * @param pipelineName The pipeline name * @param config The memory configuration - * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure * @throws IOException if XContentBuilder fails */ - private static void createPipelineInternal( - String pipelineName, - MemoryConfiguration config, - Client client, - ActionListener listener - ) throws IOException { + private void createPipelineInternal(String pipelineName, MemoryConfiguration config, ActionListener listener) + throws IOException { String processorName = config.getEmbeddingModelType() == FunctionName.TEXT_EMBEDDING ? "text_embedding" : "sparse_encoding"; XContentBuilder builder = XContentFactory @@ -218,18 +211,12 @@ private static void createPipelineInternal( * * @param config The memory configuration * @param historyIndexName The history index name - * @param indicesHandler The ML indices handler * @param listener Action listener that receives true on success, or error on failure */ - public static void createHistoryIndexIfEnabled( - MemoryConfiguration config, - String historyIndexName, - MLIndicesHandler indicesHandler, - ActionListener listener - ) { + public void createHistoryIndexIfEnabled(MemoryConfiguration config, String historyIndexName, ActionListener listener) { if (!config.isDisableHistory()) { log.debug("Creating history index: {}", historyIndexName); - indicesHandler.createLongTermMemoryHistoryIndex(historyIndexName, config, listener); + mlIndicesHandler.createLongTermMemoryHistoryIndex(historyIndexName, config, listener); } else { log.debug("History index disabled, skipping creation"); listener.onResponse(true); @@ -244,64 +231,57 @@ public static void createHistoryIndexIfEnabled( * then creates the long-term index with the pipeline attached. * If no embedding is configured, creates the index without a pipeline. 
* - * @param connectorId The connector ID for remote storage * @param indexName The long-term memory index name * @param config The memory configuration - * @param indicesHandler The ML indices handler - * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void createRemoteLongTermMemoryIngestPipeline( - String connectorId, - String indexName, - MemoryConfiguration config, - MLIndicesHandler indicesHandler, - Client client, - ActionListener listener - ) { + public void createRemoteLongTermMemoryIngestPipeline(String indexName, MemoryConfiguration config, ActionListener listener) { try { RemoteStore remoteStore = config.getRemoteStore(); + Connector connector = remoteStore.getConnector(); + if (connector != null) { + remoteMemoryStoreHelper + .createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( + connector, + indexName, + remoteStore.getIngestPipeline(), + remoteStore.getSearchPipeline(), + config, + listener + ); + return; + } + String connectorId = remoteStore.getConnectorId(); // Check if user provided a pre-existing ingest pipeline in remote_store if (remoteStore.getIngestPipeline() != null && !remoteStore.getIngestPipeline().isEmpty()) { log.info("Using pre-existing ingest pipeline from remote_store: {}", remoteStore.getIngestPipeline()); // Use the user-provided pipeline directly - org.opensearch.ml.helper.RemoteStorageHelper + remoteMemoryStoreHelper .createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( connectorId, indexName, remoteStore.getIngestPipeline(), remoteStore.getSearchPipeline(), config, - indicesHandler, - client, listener ); } else if (remoteStore.getEmbeddingModelType() != null) { // Auto-create pipeline if embedding model is configured String pipelineName = indexName + "-embedding"; - createRemoteTextEmbeddingPipeline(connectorId, pipelineName, config, client, ActionListener.wrap(success -> { + createRemoteTextEmbeddingPipeline(connectorId, pipelineName, config, ActionListener.wrap(success -> { log.info("Successfully created remote text embedding pipeline: {}", pipelineName); // Now create the remote long-term memory index with the pipeline - org.opensearch.ml.helper.RemoteStorageHelper - .createRemoteLongTermMemoryIndexWithPipeline( - connectorId, - indexName, - pipelineName, - config, - indicesHandler, - client, - listener - ); + remoteMemoryStoreHelper + .createRemoteLongTermMemoryIndexWithPipeline(connectorId, indexName, pipelineName, config, listener); }, e -> { log.error("Failed to create remote text embedding pipeline '{}'", pipelineName, e); listener.onFailure(e); })); } else { // No embedding configured, create index without pipeline - org.opensearch.ml.helper.RemoteStorageHelper - .createRemoteLongTermMemoryIndex(connectorId, indexName, config, indicesHandler, client, listener); + remoteMemoryStoreHelper.createRemoteLongTermMemoryIndex(connectorId, indexName, config, listener); } } catch (Exception e) { log.error("Failed to create remote long-term memory infrastructure for index: {}", indexName, e); @@ -319,14 +299,12 @@ public static void createRemoteLongTermMemoryIngestPipeline( * @param connectorId The connector ID for remote storage * @param pipelineName The pipeline name * @param config The memory configuration - * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void createRemoteTextEmbeddingPipeline( + public void createRemoteTextEmbeddingPipeline( String 
connectorId, String pipelineName, MemoryConfiguration config, - Client client, ActionListener listener ) { try { @@ -356,7 +334,7 @@ public static void createRemoteTextEmbeddingPipeline( String pipelineBody = builder.toString(); // Use RemoteStorageHelper to create the pipeline in remote storage - org.opensearch.ml.helper.RemoteStorageHelper.createRemotePipeline(connectorId, pipelineName, pipelineBody, client, listener); + remoteMemoryStoreHelper.createRemotePipeline(connectorId, pipelineName, pipelineBody, listener); } catch (IOException e) { log.error("Failed to build remote pipeline configuration for '{}'", pipelineName, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java similarity index 60% rename from plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java rename to plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java index a830d4a56a..0e199708c9 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteStorageHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java @@ -40,6 +40,7 @@ import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -51,6 +52,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; @@ -58,11 +60,16 @@ import org.opensearch.ml.common.memorycontainer.MemoryStrategy; import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.MLEngineClassLoader; import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.script.ScriptService; import org.opensearch.transport.client.Client; import lombok.extern.log4j.Log4j2; @@ -71,7 +78,7 @@ * Helper class for creating memory indices in remote storage using connectors */ @Log4j2 -public class RemoteStorageHelper { +public class RemoteMemoryStoreHelper { public static final String REGISTER_MODEL_ACTION = "register_model"; public static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; @@ -88,23 +95,43 @@ public class RemoteStorageHelper { public static final String HEADERS_FIELD = "headers"; public static final String ACTIONS_FIELD = "actions"; + private final Client client; + private final ClusterService clusterService; + private final ScriptService scriptService; + 
private final NamedXContentRegistry xContentRegistry; + private final Encryptor encryptor; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + private final MLIndicesHandler mlIndicesHandler; + + public RemoteMemoryStoreHelper( + Client client, + ClusterService clusterService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Encryptor encryptor, + MLFeatureEnabledSetting mlFeatureEnabledSetting, + MLIndicesHandler mlIndicesHandler + ) { + + this.client = client; + this.clusterService = clusterService; + this.scriptService = scriptService; + this.xContentRegistry = xContentRegistry; + this.encryptor = encryptor; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + this.mlIndicesHandler = mlIndicesHandler; + } + /** * Creates a memory index in remote storage using a connector * * @param connectorId The connector ID to use for remote storage * @param indexName The name of the index to create * @param indexMapping The index mapping as a JSON string - * @param client The OpenSearch client * @param listener The action listener */ - public static void createRemoteIndex( - String connectorId, - String indexName, - String indexMapping, - Client client, - ActionListener listener - ) { - createRemoteIndex(connectorId, indexName, indexMapping, null, client, listener); + public void createRemoteIndex(String connectorId, String indexName, String indexMapping, ActionListener listener) { + createRemoteIndex(connectorId, indexName, indexMapping, null, listener); } /** @@ -114,15 +141,13 @@ public static void createRemoteIndex( * @param indexName The name of the index to create * @param indexMapping The index mapping as a JSON string * @param indexSettings The index settings as a Map (can be null) - * @param client The OpenSearch client * @param listener The action listener */ - public static void createRemoteIndex( + public void createRemoteIndex( String connectorId, String indexName, String indexMapping, Map indexSettings, - Client client, ActionListener listener ) { try { @@ -145,7 +170,48 @@ public static void createRemoteIndex( parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); // Execute the connector action - executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, CREATE_INDEX_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully created remote index: {}", indexName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote index creation for: {}", indexName, e); + listener.onFailure(e); + } + } + + public void createRemoteIndex( + Connector connector, + String indexName, + String indexMapping, + Map indexSettings, + ActionListener listener + ) { + try { + // Parse the mapping string to a Map + Map mappingMap = parseMappingToMap(indexMapping); + + // Build the request body for creating the index + Map requestBody = new HashMap<>(); + requestBody.put("mappings", mappingMap); + + // Add settings if provided (settings should already have "index." 
prefix) + if (indexSettings != null && !indexSettings.isEmpty()) { + requestBody.put("settings", indexSettings); + } + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJson(requestBody)); + parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); + + // Execute the connector action + executeConnectorAction(connector, CREATE_INDEX_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully created remote index: {}", indexName); listener.onResponse(true); }, e -> { @@ -162,66 +228,91 @@ public static void createRemoteIndex( /** * Creates session memory index in remote storage */ - public static void createRemoteSessionMemoryIndex( + public void createRemoteSessionMemoryIndex( String connectorId, String indexName, MemoryConfiguration configuration, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { String indexMappings = mlIndicesHandler.getMapping(ML_MEMORY_SESSION_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(SESSION_INDEX); - createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, client, listener); + createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, listener); + } + + public void createRemoteSessionMemoryIndex( + Connector connector, + String indexName, + MemoryConfiguration configuration, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_MEMORY_SESSION_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(SESSION_INDEX); + createRemoteIndex(connector, indexName, indexMappings, indexSettings, listener); } /** * Creates working memory index in remote storage */ - public static void createRemoteWorkingMemoryIndex( + public void createRemoteWorkingMemoryIndex( String connectorId, String indexName, MemoryConfiguration configuration, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { String indexMappings = mlIndicesHandler.getMapping(ML_WORKING_MEMORY_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(WORKING_MEMORY_INDEX); - createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, client, listener); + createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, listener); + } + + public void createRemoteWorkingMemoryIndex( + Connector connector, + String indexName, + MemoryConfiguration configuration, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_WORKING_MEMORY_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(WORKING_MEMORY_INDEX); + createRemoteIndex(connector, indexName, indexMappings, indexSettings, listener); } /** * Creates long-term memory history index in remote storage */ - public static void createRemoteLongTermMemoryHistoryIndex( + public void createRemoteLongTermMemoryHistoryIndex( String connectorId, String indexName, MemoryConfiguration configuration, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { String indexMappings = mlIndicesHandler.getMapping(ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(LONG_TERM_MEMORY_HISTORY_INDEX); - createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, client, listener); + createRemoteIndex(connectorId, indexName, indexMappings, indexSettings, listener); + } + + public void 
createRemoteLongTermMemoryHistoryIndex( + Connector connector, + String indexName, + MemoryConfiguration configuration, + ActionListener listener + ) { + String indexMappings = mlIndicesHandler.getMapping(ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH); + Map indexSettings = configuration.getMemoryIndexMapping(LONG_TERM_MEMORY_HISTORY_INDEX); + createRemoteIndex(connector, indexName, indexMappings, indexSettings, listener); } /** * Creates long-term memory index in remote storage with dynamic embedding configuration */ - public static void createRemoteLongTermMemoryIndex( + public void createRemoteLongTermMemoryIndex( String connectorId, String indexName, MemoryConfiguration memoryConfig, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { try { String indexMapping = buildLongTermMemoryMapping(memoryConfig, mlIndicesHandler); Map indexSettings = buildLongTermMemorySettings(memoryConfig); - createRemoteIndex(connectorId, indexName, indexMapping, indexSettings, client, listener); + createRemoteIndex(connectorId, indexName, indexMapping, indexSettings, listener); } catch (Exception e) { log.error("Failed to build long-term memory mapping for remote index: {}", indexName, e); listener.onFailure(e); @@ -231,8 +322,7 @@ public static void createRemoteLongTermMemoryIndex( /** * Builds the long-term memory index mapping dynamically based on configuration */ - private static String buildLongTermMemoryMapping(MemoryConfiguration memoryConfig, MLIndicesHandler mlIndicesHandler) - throws IOException { + private String buildLongTermMemoryMapping(MemoryConfiguration memoryConfig, MLIndicesHandler mlIndicesHandler) throws IOException { String baseMappingJson = mlIndicesHandler.getMapping(ML_LONG_TERM_MEMORY_INDEX_MAPPING_PATH); Map mapping = new HashMap<>(); @@ -269,7 +359,7 @@ private static String buildLongTermMemoryMapping(MemoryConfiguration memoryConfi * Builds the long-term memory index settings dynamically based on configuration * Returns settings with "index." 
prefix as required by OpenSearch/AOSS */ - private static Map buildLongTermMemorySettings(MemoryConfiguration memoryConfig) { + private Map buildLongTermMemorySettings(MemoryConfiguration memoryConfig) { Map indexSettings = new HashMap<>(); RemoteStore remoteStore = memoryConfig.getRemoteStore(); @@ -289,11 +379,10 @@ private static Map buildLongTermMemorySettings(MemoryConfigurati /** * Executes a connector action with a specific action name */ - private static void executeConnectorAction( + private void executeConnectorAction( String connectorId, String actionName, Map parameters, - Client client, ActionListener listener ) { // Add connector_action parameter to specify which action to execute @@ -313,16 +402,37 @@ private static Map buildLongTermMemorySettings(MemoryConfigurati })); } - /** - * Executes a connector action (backward compatibility - defaults to create_index) - */ - private static void executeConnectorAction( - String connectorId, + private void executeConnectorAction( + Connector connector, + String actionName, Map parameters, - Client client, ActionListener listener ) { - executeConnectorAction(connectorId, CREATE_INDEX_ACTION, parameters, client, listener); + // Add connector_action parameter to specify which action to execute + Map allParameters = new HashMap<>(parameters); + allParameters.put(CONNECTOR_ACTION_FIELD, actionName); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(allParameters).build(); + MLInput mlInput = RemoteInferenceMLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + + runConnector(connector, actionName, mlInput, listener); + } + + private void runConnector(Connector connector, String actionName, MLInput mlInput, ActionListener actionListener) { + if (connector == null) { + throw new IllegalArgumentException("connector is null"); + } + // passing tenantId as null because multi-tenancy is not implemented for this feature yet.
+ connector.decrypt(actionName, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(actionName, mlInput, ActionListener.wrap(taskResponse -> { + actionListener.onResponse((ModelTensorOutput) taskResponse.getOutput()); + }, e -> { actionListener.onFailure(e); })); } /** @@ -334,11 +444,34 @@ private static void executeConnectorAction( * @param client The OpenSearch client * @param listener The action listener */ - public static void writeDocument( + /** + * Writes a single document to remote storage using RemoteStore configuration + * + * @param remoteStore The remote store configuration (supports both connectorId and internal connector) + * @param indexName The name of the index + * @param documentSource The document source as a Map + * @param listener The action listener + */ + public void writeDocument( + RemoteStore remoteStore, + String indexName, + Map documentSource, + ActionListener listener + ) { + // If connectorId is provided, use the existing method + if (remoteStore.getConnectorId() != null) { + writeDocument(remoteStore.getConnectorId(), indexName, documentSource, listener); + } else if (remoteStore.getConnector() != null) { + writeDocument(remoteStore.getConnector(), indexName, documentSource, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + + public void writeDocument( String connectorId, String indexName, Map documentSource, - Client client, ActionListener listener ) { try { @@ -348,7 +481,36 @@ public static void writeDocument( parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); // Execute the connector action with write_doc action name - executeConnectorAction(connectorId, WRITE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, WRITE_DOC_ACTION, parameters, ActionListener.wrap(response -> { + // Extract document ID from response + XContentParser parser = createParserFromTensorOutput(response); + IndexResponse indexResponse = IndexResponse.fromXContent(parser); + listener.onResponse(indexResponse); + }, e -> { + log.error("Failed to write document to remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document write for index: {}", indexName, e); + listener.onFailure(e); + } + } + + public void writeDocument( + Connector connector, + String indexName, + Map documentSource, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); + + // Execute the connector action with write_doc action name + executeConnectorAction(connector, WRITE_DOC_ACTION, parameters, ActionListener.wrap(response -> { // Extract document ID from response XContentParser parser = createParserFromTensorOutput(response); IndexResponse indexResponse = 
IndexResponse.fromXContent(parser); @@ -364,15 +526,51 @@ public static void writeDocument( } } + /** + * Performs bulk write operations to remote storage using RemoteStore configuration + */ + public void bulkWrite(RemoteStore remoteStore, String bulkBody, ActionListener listener) { + if (remoteStore.getConnectorId() != null) { + bulkWrite(remoteStore.getConnectorId(), bulkBody, listener); + } else if (remoteStore.getConnector() != null) { + bulkWrite(remoteStore.getConnector(), bulkBody, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + /** * Performs bulk write operations to remote storage * * @param connectorId The connector ID to use for remote storage * @param bulkBody The bulk request body in NDJSON format - * @param client The OpenSearch client * @param listener The action listener */ - public static void bulkWrite(String connectorId, String bulkBody, Client client, ActionListener listener) { + public void bulkWrite(String connectorId, String bulkBody, ActionListener listener) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INPUT_PARAM, bulkBody); + parameters.put(NO_ESCAPE_PARAMS, INPUT_PARAM); + + // Execute the connector action with bulk_load action name + executeConnectorAction(connectorId, BULK_LOAD_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully executed bulk write to remote storage"); + XContentParser parser = createParserFromTensorOutput(response); + BulkResponse bulkResponse = BulkResponse.fromXContent(parser); + listener.onResponse(bulkResponse); + }, e -> { + log.error("Failed to execute bulk write to remote storage", e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote bulk write", e); + listener.onFailure(e); + } + } + + public void bulkWrite(Connector connector, String bulkBody, ActionListener listener) { try { // Prepare parameters for connector execution Map parameters = new HashMap<>(); @@ -380,7 +578,7 @@ public static void bulkWrite(String connectorId, String bulkBody, Client client, parameters.put(NO_ESCAPE_PARAMS, INPUT_PARAM); // Execute the connector action with bulk_load action name - executeConnectorAction(connectorId, BULK_LOAD_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connector, BULK_LOAD_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully executed bulk write to remote storage"); XContentParser parser = createParserFromTensorOutput(response); BulkResponse bulkResponse = BulkResponse.fromXContent(parser); @@ -396,12 +594,30 @@ public static void bulkWrite(String connectorId, String bulkBody, Client client, } } - public static void searchDocuments( + /** + * Searches documents in remote storage using RemoteStore configuration + */ + public void searchDocuments( + RemoteStore remoteStore, + String indexName, + String query, + String searchPipeline, + ActionListener listener + ) { + if (remoteStore.getConnectorId() != null) { + searchDocuments(remoteStore.getConnectorId(), indexName, query, searchPipeline, listener); + } else if (remoteStore.getConnector() != null) { + searchDocuments(remoteStore.getConnector(), indexName, query, searchPipeline, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + + public void searchDocuments( String 
connectorId, String indexName, String query, String searchPipeline, - Client client, ActionListener listener ) { try { @@ -414,7 +630,40 @@ public static void searchDocuments( parameters.put(INPUT_PARAM, query); // Execute the connector action with search_index action name - executeConnectorAction(connectorId, SEARCH_INDEX_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, SEARCH_INDEX_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully searched documents in remote index: {}", indexName); + XContentParser parser = createParserFromTensorOutput(response); + SearchResponse searchResponse = SearchResponse.fromXContent(parser); + listener.onResponse(searchResponse); + }, e -> { + log.error("Failed to search documents in remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote search for index: {}", indexName, e); + listener.onFailure(e); + } + } + + public void searchDocuments( + Connector connector, + String indexName, + String query, + String searchPipeline, + ActionListener listener + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + if (searchPipeline != null) { + parameters.put(SEARCH_PIPELINE_FIELD, "?search_pipeline=" + searchPipeline); + } + parameters.put(INPUT_PARAM, query); + + // Execute the connector action with search_index action name + executeConnectorAction(connector, SEARCH_INDEX_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully searched documents in remote index: {}", indexName); XContentParser parser = createParserFromTensorOutput(response); SearchResponse searchResponse = SearchResponse.fromXContent(parser); @@ -430,6 +679,25 @@ public static void searchDocuments( } } + /** + * Updates a document in remote storage using RemoteStore configuration + */ + public void updateDocument( + RemoteStore remoteStore, + String indexName, + String docId, + Map documentSource, + ActionListener listener + ) { + if (remoteStore.getConnectorId() != null) { + updateDocument(remoteStore.getConnectorId(), indexName, docId, documentSource, listener); + } else if (remoteStore.getConnector() != null) { + updateDocument(remoteStore.getConnector(), indexName, docId, documentSource, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + /** * Updates a document in remote storage * @@ -437,15 +705,13 @@ public static void searchDocuments( * @param indexName The name of the index * @param docId The document ID to update * @param documentSource The document source as a Map - * @param client The OpenSearch client * @param listener The action listener */ - public static void updateDocument( + public void updateDocument( String connectorId, String indexName, String docId, Map documentSource, - Client client, ActionListener listener ) { try { @@ -456,7 +722,7 @@ public static void updateDocument( parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); // Execute the connector action with update_doc action name - executeConnectorAction(connectorId, UPDATE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, UPDATE_DOC_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully updated document in remote index: {}, doc_id: {}", indexName, docId); 
XContentParser parser = createParserFromTensorOutput(response); UpdateResponse updateResponse = UpdateResponse.fromXContent(parser); @@ -472,13 +738,77 @@ public static void updateDocument( } } - public static void getDocument( - String connectorId, + public void updateDocument( + Connector connector, String indexName, String docId, - Client client, - ActionListener listener + Map documentSource, + ActionListener listener ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); + + // Execute the connector action with update_doc action name + executeConnectorAction(connector, UPDATE_DOC_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully updated document in remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + UpdateResponse updateResponse = UpdateResponse.fromXContent(parser); + listener.onResponse(updateResponse); + }, e -> { + log.error("Failed to update document in remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document update for index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + } + } + + /** + * Gets a document from remote storage using RemoteStore configuration + */ + public void getDocument(RemoteStore remoteStore, String indexName, String docId, ActionListener listener) { + if (remoteStore.getConnectorId() != null) { + getDocument(remoteStore.getConnectorId(), indexName, docId, listener); + } else if (remoteStore.getConnector() != null) { + getDocument(remoteStore.getConnector(), indexName, docId, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + + public void getDocument(String connectorId, String indexName, String docId, ActionListener listener) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + // input parameter is optional for get, use empty string as default + parameters.put(INPUT_PARAM, ""); + + // Execute the connector action with get_doc action name + executeConnectorAction(connectorId, GET_DOC_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully retrieved document from remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + GetResponse getResponse = GetResponse.fromXContent(parser); + listener.onResponse(getResponse); + }, e -> { + log.error("Failed to get document from remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document get for index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + } + } + + public void getDocument(Connector connector, String indexName, String docId, ActionListener listener) { try { // Prepare parameters for connector execution Map parameters = new HashMap<>(); @@ -488,7 +818,7 @@ public static void getDocument( parameters.put(INPUT_PARAM, ""); // Execute the connector action with delete_doc action name - executeConnectorAction(connectorId, GET_DOC_ACTION, parameters, client,
ActionListener.wrap(response -> { + executeConnectorAction(connector, GET_DOC_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully deleted document from remote index: {}, doc_id: {}", indexName, docId); XContentParser parser = createParserFromTensorOutput(response); GetResponse getResponse = GetResponse.fromXContent(parser); @@ -504,22 +834,28 @@ public static void getDocument( } } + /** + * Deletes a document from remote storage using RemoteStore configuration + */ + public void deleteDocument(RemoteStore remoteStore, String indexName, String docId, ActionListener listener) { + if (remoteStore.getConnectorId() != null) { + deleteDocument(remoteStore.getConnectorId(), indexName, docId, listener); + } else if (remoteStore.getConnector() != null) { + deleteDocument(remoteStore.getConnector(), indexName, docId, listener); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + /** * Deletes a document from remote storage * * @param connectorId The connector ID to use for remote storage * @param indexName The name of the index * @param docId The document ID to delete - * @param client The OpenSearch client * @param listener The action listener */ - public static void deleteDocument( - String connectorId, - String indexName, - String docId, - Client client, - ActionListener listener - ) { + public void deleteDocument(String connectorId, String indexName, String docId, ActionListener listener) { try { // Prepare parameters for connector execution Map parameters = new HashMap<>(); @@ -529,7 +865,33 @@ public static void deleteDocument( parameters.put(INPUT_PARAM, ""); // Execute the connector action with delete_doc action name - executeConnectorAction(connectorId, DELETE_DOC_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, DELETE_DOC_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully deleted document from remote index: {}, doc_id: {}", indexName, docId); + XContentParser parser = createParserFromTensorOutput(response); + DeleteResponse deleteResponse = DeleteResponse.fromXContent(parser); + listener.onResponse(deleteResponse); + }, e -> { + log.error("Failed to delete document from remote index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document delete for index: {}, doc_id: {}", indexName, docId, e); + listener.onFailure(e); + } + } + + public void deleteDocument(Connector connector, String indexName, String docId, ActionListener listener) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + // input parameter is optional for delete, use empty string as default + parameters.put(INPUT_PARAM, ""); + + // Execute the connector action with delete_doc action name + executeConnectorAction(connector, DELETE_DOC_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully deleted document from remote index: {}, doc_id: {}", indexName, docId); XContentParser parser = createParserFromTensorOutput(response); DeleteResponse deleteResponse = DeleteResponse.fromXContent(parser); @@ -548,20 +910,20 @@ public static void deleteDocument( /** * Parses a JSON mapping string to a Map */ - private static Map parseMappingToMap(String mappingJson) throws IOException { + private Map 
parseMappingToMap(String mappingJson) throws IOException { XContentParser parser = XContentHelper .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, new BytesArray(mappingJson), XContentType.JSON); return parser.mapOrdered(); } - public static XContentParser createParserFromTensorOutput(ModelTensorOutput output) throws IOException { + public XContentParser createParserFromTensorOutput(ModelTensorOutput output) throws IOException { Map dataAsMap = output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); String json = StringUtils.toJson(dataAsMap); XContentParser parser = jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); return parser; } - public static QueryBuilder buildFactSearchQuery( + public QueryBuilder buildFactSearchQuery( MemoryStrategy strategy, String fact, Map namespace, @@ -627,16 +989,9 @@ public static QueryBuilder buildFactSearchQuery( * @param connectorId The connector ID to use for remote storage * @param pipelineName The name of the pipeline to create * @param pipelineBody The pipeline configuration as a JSON string - * @param client The OpenSearch client * @param listener The action listener */ - public static void createRemotePipeline( - String connectorId, - String pipelineName, - String pipelineBody, - Client client, - ActionListener listener - ) { + public void createRemotePipeline(String connectorId, String pipelineName, String pipelineBody, ActionListener listener) { try { // Prepare parameters for connector execution Map parameters = new HashMap<>(); @@ -645,7 +1000,7 @@ public static void createRemotePipeline( parameters.put(CONNECTOR_ACTION_FIELD, "create_ingest_pipeline"); // Execute the connector action - executeConnectorAction(connectorId, CREATE_INGEST_PIPELINE_ACTION, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, CREATE_INGEST_PIPELINE_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully created remote pipeline: {}", pipelineName); listener.onResponse(true); }, e -> { @@ -666,17 +1021,13 @@ public static void createRemotePipeline( * @param indexName The name of the index to create * @param pipelineName The name of the pipeline to attach * @param memoryConfig The memory configuration - * @param mlIndicesHandler The ML indices handler - * @param client The OpenSearch client * @param listener The action listener */ - public static void createRemoteLongTermMemoryIndexWithPipeline( + public void createRemoteLongTermMemoryIndexWithPipeline( String connectorId, String indexName, String pipelineName, MemoryConfiguration memoryConfig, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { try { @@ -702,7 +1053,7 @@ public static void createRemoteLongTermMemoryIndexWithPipeline( parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); // Execute the connector action - executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, CREATE_INDEX_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully created remote long-term memory index with pipeline: {}", indexName); listener.onResponse(true); }, e -> { @@ -716,14 +1067,12 @@ public static void createRemoteLongTermMemoryIndexWithPipeline( } } - public static void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( + public void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( String connectorId, String indexName, String 
ingestPipelineName, String searchPipelineName, MemoryConfiguration memoryConfig, - MLIndicesHandler mlIndicesHandler, - Client client, ActionListener listener ) { try { @@ -752,7 +1101,55 @@ public static void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); // Execute the connector action - executeConnectorAction(connectorId, parameters, client, ActionListener.wrap(response -> { + executeConnectorAction(connectorId, CREATE_INDEX_ACTION, parameters, ActionListener.wrap(response -> { + log.info("Successfully created remote long-term memory index with pipeline: {}", indexName); + listener.onResponse(true); + }, e -> { + log.error("Failed to create remote long-term memory index with pipeline: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote long-term memory index creation with pipeline for: {}", indexName, e); + listener.onFailure(e); + } + } + + public void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( + Connector connector, + String indexName, + String ingestPipelineName, + String searchPipelineName, + MemoryConfiguration memoryConfig, + ActionListener listener + ) { + try { + String indexMapping = buildLongTermMemoryMapping(memoryConfig, mlIndicesHandler); + Map indexSettings = buildLongTermMemorySettings(memoryConfig); + + // Parse the mapping string to a Map + Map mappingMap = parseMappingToMap(indexMapping); + + // Build the request body for creating the index with pipeline + Map requestBody = new HashMap<>(); + requestBody.put("mappings", mappingMap); + + // Add settings with default pipeline (settings already have "index." prefix) + Map settings = new HashMap<>(indexSettings); + settings.put("index.default_pipeline", ingestPipelineName); + if (searchPipelineName != null) { + settings.put("index.search.default_pipeline", searchPipelineName); + } + requestBody.put("settings", settings); + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(INPUT_PARAM, StringUtils.toJson(requestBody)); + parameters.put(CONNECTOR_ACTION_FIELD, CREATE_INDEX_ACTION); + + // Execute the connector action + executeConnectorAction(connector, CREATE_INDEX_ACTION, parameters, ActionListener.wrap(response -> { log.info("Successfully created remote long-term memory index with pipeline: {}", indexName); listener.onResponse(true); }, e -> { @@ -772,14 +1169,12 @@ public static void createRemoteLongTermMemoryIndexWithIngestAndSearchPipeline( * @param connectorId The connector ID to use for remote storage * @param embeddingModel The embedding model configuration * @param remoteStoreCredential The remote store credentials (used if embedding model doesn't have its own) - * @param client The OpenSearch client * @param listener The action listener that receives the created model ID */ - public static void createRemoteEmbeddingModel( + public void createRemoteEmbeddingModel( String connectorId, org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, Map remoteStoreCredential, - Client client, ActionListener listener ) { try { @@ -793,7 +1188,40 @@ public static void createRemoteEmbeddingModel( parameters.put(SKIP_VALIDATE_MISSING_PARAMETERS, "true"); // Execute the connector action with register_model action name - executeConnectorAction(connectorId, REGISTER_MODEL_ACTION, parameters, client, ActionListener.wrap(response -> { + 
executeConnectorAction(connectorId, REGISTER_MODEL_ACTION, parameters, ActionListener.wrap(response -> { + // Parse model_id from response + String modelId = extractModelIdFromResponse(response); + log.info("Successfully created embedding model in remote store: {}", modelId); + listener.onResponse(modelId); + }, e -> { + log.error("Failed to create embedding model in remote store", e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error building embedding model registration request", e); + listener.onFailure(e); + } + } + + public void createRemoteEmbeddingModel( + Connector connector, + org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, + Map remoteStoreCredential, + ActionListener listener + ) { + try { + // Build model registration request body + String requestBody = buildEmbeddingModelRegistrationBody(embeddingModel, remoteStoreCredential); + + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INPUT_PARAM, requestBody); + parameters.put(NO_ESCAPE_PARAMS, INPUT_PARAM); + parameters.put(SKIP_VALIDATE_MISSING_PARAMETERS, "true"); + + // Execute the connector action with register_model action name + executeConnectorAction(connector, REGISTER_MODEL_ACTION, parameters, ActionListener.wrap(response -> { // Parse model_id from response String modelId = extractModelIdFromResponse(response); log.info("Successfully created embedding model in remote store: {}", modelId); @@ -812,7 +1240,7 @@ public static void createRemoteEmbeddingModel( /** * Builds the request body for embedding model registration in remote AOSS */ - private static String buildEmbeddingModelRegistrationBody( + private String buildEmbeddingModelRegistrationBody( org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, Map remoteStoreCredential ) { @@ -838,7 +1266,7 @@ private static String buildEmbeddingModelRegistrationBody( /** * Builds Bedrock embedding connector configuration from template */ - private static String buildBedrockEmbeddingConnectorConfig( + private String buildBedrockEmbeddingConnectorConfig( String provider, org.opensearch.ml.common.memorycontainer.RemoteEmbeddingModel embeddingModel, Map remoteStoreCredential @@ -874,7 +1302,7 @@ private static String buildBedrockEmbeddingConnectorConfig( /** * Injects parameters and credential into the connector template */ - private static String injectParametersAndCredential(String template, Map parameters, Map credential) + private String injectParametersAndCredential(String template, Map parameters, Map credential) throws IOException { // Parse template as JSON XContentParser parser = XContentHelper @@ -924,12 +1352,12 @@ private static String injectParametersAndCredential(String template, Map dataAsMap = response.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); Object modelIdObj = dataAsMap.get("model_id"); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index dc47e87059..55cdfc67b5 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -286,7 +286,9 @@ import org.opensearch.ml.engine.tools.WriteToScratchPadTool; import org.opensearch.ml.engine.utils.AgentModelsSearcher; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.helper.MemoryContainerPipelineHelper; import 
org.opensearch.ml.helper.ModelAccessControlHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.ml.jobs.MLJobParameter; import org.opensearch.ml.jobs.MLJobRunner; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -524,6 +526,8 @@ public class MachineLearningPlugin extends Plugin private Encryptor encryptor; private McpToolsHelper mcpToolsHelper; private McpStatelessServerHolder statelessServerHolder; + private RemoteMemoryStoreHelper remoteMemoryStoreHelper; + private MemoryContainerPipelineHelper memoryContainerPipelineHelper; public MachineLearningPlugin() {} @@ -919,6 +923,16 @@ public Collection createComponents( mcpToolsHelper = new McpToolsHelper(client, toolFactoryWrapper); statelessServerHolder = new McpStatelessServerHolder(mcpToolsHelper, client, threadPool); + remoteMemoryStoreHelper = new RemoteMemoryStoreHelper( + client, + clusterService, + scriptService, + xContentRegistry, + encryptor, + mlFeatureEnabledSetting, + mlIndicesHandler + ); + memoryContainerPipelineHelper = new MemoryContainerPipelineHelper(client, mlIndicesHandler, remoteMemoryStoreHelper); return ImmutableList .of( encryptor, @@ -950,7 +964,9 @@ public Collection createComponents( sdkClient, toolFactoryWrapper, mcpToolsHelper, - statelessServerHolder + statelessServerHolder, + remoteMemoryStoreHelper, + memoryContainerPipelineHelper ); } From 9a9a35885739db71adfa90255cc2bf8646ad93f2 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Wed, 29 Oct 2025 17:35:20 -0700 Subject: [PATCH 15/58] fix: support old style agent register + new style register Signed-off-by: Pavan Yekbote --- .../src/main/java/org/opensearch/ml/common/agent/MLAgent.java | 2 +- .../ml/engine/algorithms/agent/MLAgentExecutor.java | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 1e2f3509f8..c6c6607de6 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -136,7 +136,7 @@ private void validate() { ); } MLAgentType.from(type); - if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null) { + if (type.equalsIgnoreCase(MLAgentType.CONVERSATIONAL.toString()) && llm == null && model == null) { throw new IllegalArgumentException("We need model information for the conversational agent type"); } Set toolNames = new HashSet<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 3ddfa8b214..168f626091 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -742,6 +742,10 @@ private void updateInteractionWithFailure(String interactionId, Memory memory, S * by the existing agent execution logic. 
*/ private void processAgentInput(AgentMLInput agentMLInput, MLAgent mlAgent) { + // old style agent registration + if (mlAgent.getModel() == null) { + return; + } // If legacy question input is provided, parse to new standard input if (agentMLInput.getInputDataset() != null) { RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset(); From df354b82d1858e157feb379e79226ea40b12cae3 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Thu, 30 Oct 2025 13:18:26 -0400 Subject: [PATCH 16/58] fix: memory working with agent Signed-off-by: Pavan Yekbote --- .../ml/engine/algorithms/agent/MLAgentExecutor.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 168f626091..327c254f7f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -746,15 +746,14 @@ private void processAgentInput(AgentMLInput agentMLInput, MLAgent mlAgent) { if (mlAgent.getModel() == null) { return; } + // If legacy question input is provided, parse to new standard input if (agentMLInput.getInputDataset() != null) { RemoteInferenceInputDataSet remoteInferenceInputDataSet = (RemoteInferenceInputDataSet) agentMLInput.getInputDataset(); - if (!remoteInferenceInputDataSet.getParameters().containsKey(QUESTION)) { - throw new IllegalArgumentException("Question not found in parameters."); + if (remoteInferenceInputDataSet.getParameters().containsKey(QUESTION)) { + AgentInput standardInput = new AgentInput(remoteInferenceInputDataSet.getParameters().get(QUESTION)); + agentMLInput.setAgentInput(standardInput); } - - AgentInput standardInput = new AgentInput(remoteInferenceInputDataSet.getParameters().get(QUESTION)); - agentMLInput.setAgentInput(standardInput); } try { From d66e924ca2438e7b5d8c88ab837ddf9936beaddc Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 31 Oct 2025 18:09:48 -0700 Subject: [PATCH 17/58] support memory container id in execute agent request Signed-off-by: Yaliang Wu --- .../ml/engine/algorithms/agent/AgentUtils.java | 9 ++++++++- .../ml/engine/algorithms/agent/MLAgentExecutor.java | 12 ++++++++++-- .../engine/algorithms/agent/MLChatAgentRunner.java | 3 ++- .../agent/MLConversationalFlowAgentRunner.java | 3 ++- .../agent/MLPlanExecuteAndReflectAgentRunner.java | 9 ++++++++- 5 files changed, 30 insertions(+), 6 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 3df1cb0c66..cbc003388f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -1018,7 +1018,13 @@ public static Tool createTool(Map toolFactories, Map createMemoryParams(String question, String memoryId, String appType, MLAgent mlAgent) { + public static Map createMemoryParams( + String question, + String memoryId, + String appType, + MLAgent mlAgent, + String memoryContainerId + ) { Map memoryParams = new HashMap<>(); memoryParams.put(ConversationIndexMemory.MEMORY_NAME, question); memoryParams.put(ConversationIndexMemory.MEMORY_ID, memoryId); @@ 
-1026,6 +1032,7 @@ public static Map createMemoryParams(String question, String mem if (mlAgent.getMemory().getMemoryContainerId() != null) { memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId()); } + memoryParams.putIfAbsent(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); return memoryParams; } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 327c254f7f..87a266a490 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -13,6 +13,7 @@ import static org.opensearch.ml.common.MLTask.RESPONSE_FIELD; import static org.opensearch.ml.common.MLTask.STATE_FIELD; import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; @@ -257,7 +258,13 @@ public void execute(Input input, ActionListener listener, TransportChann Memory.Factory> memoryFactory = memoryFactoryMap .get(MLMemoryType.from(memorySpec.getType()).name()); - Map memoryParams = createMemoryParams(question, memoryId, appType, mlAgent); + Map memoryParams = createMemoryParams( + question, + memoryId, + appType, + mlAgent, + inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) + ); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { inputDataSet.getParameters().put(MEMORY_ID, memory.getId()); // get question for regenerate @@ -319,7 +326,8 @@ public void execute(Input input, ActionListener listener, TransportChann question, memoryId, appType, - mlAgent + mlAgent, + inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) ); memoryFactory diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 90e2e4dfb8..21b1bc5762 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,6 +7,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -190,7 +191,7 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener Memory.Factory> memoryFactory = memoryFactoryMap.get(memoryType); - Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent); + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { // TODO: call runAgent directly if 
messageHistoryLimit == 0 memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index 87e7b761f0..faeec6b050 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.convertOutputToModelTensor; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -111,7 +112,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); - Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent); + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 2a3673824c..8d8a854217 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; import static org.opensearch.ml.common.utils.StringUtils.isJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; @@ -293,7 +294,13 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); - Map memoryParams = createMemoryParams(apiParams.get(USER_PROMPT_FIELD), memoryId, appType, mlAgent); + Map memoryParams = createMemoryParams( + apiParams.get(USER_PROMPT_FIELD), + memoryId, + appType, + mlAgent, + apiParams.get(MEMORY_CONTAINER_ID_FIELD) + ); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { List completedSteps = new ArrayList<>(); From bb477870602d0c440ddae7ba5e608ed71352f879 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Tue, 4 Nov 2025 
16:17:08 -0800 Subject: [PATCH 18/58] [Feature branch] Introduce hook and context management to OpenSearch Agents (#4397) * add hooks in ml-commons (#4326) Signed-off-by: Xun Zhang * initiate context management api with hook implementation (#4345) * initiate context management api with hook implementation Signed-off-by: Mingshi Liu * apply spotless Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * Add Context Manager to PER (#4379) * add pre_llm hook to per agent Signed-off-by: Mingshi Liu change context management passing from query parameters to payload Signed-off-by: Mingshi Liu pass hook registery into PER Signed-off-by: Mingshi Liu apply spotless Signed-off-by: Mingshi Liu initiate context management api with hook implementation Signed-off-by: Mingshi Liu * add comment Signed-off-by: Mingshi Liu * format Signed-off-by: Mingshi Liu * add validation Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu * add inner create context management to agent register api Signed-off-by: Mingshi Liu * add code coverage Signed-off-by: Mingshi Liu * allow context management hook register in during agent execute Signed-off-by: Mingshi Liu * add code coverage Signed-off-by: Mingshi Liu * add more code coverage Signed-off-by: Mingshi Liu * add validation check Signed-off-by: Mingshi Liu * adapt to inplace update for context Signed-off-by: Mingshi Liu * fix test Signed-off-by: Mingshi Liu --------- Signed-off-by: Xun Zhang Signed-off-by: Mingshi Liu Co-authored-by: Xun Zhang --- .../org/opensearch/ml/common/CommonValue.java | 2 + .../opensearch/ml/common/agent/MLAgent.java | 105 +- .../common/contextmanager/ActivationRule.java | 26 + .../contextmanager/ActivationRuleFactory.java | 146 ++ .../CharacterBasedTokenCounter.java | 89 + .../ContextManagementTemplate.java | 260 +++ .../common/contextmanager/ContextManager.java | 42 + .../contextmanager/ContextManagerConfig.java | 127 ++ .../contextmanager/ContextManagerContext.java | 180 ++ .../ContextManagerHookProvider.java | 193 ++ .../MessageCountExceedRule.java | 34 + .../common/contextmanager/TokenCounter.java | 45 + .../contextmanager/TokensExceedRule.java | 34 + .../common/contextmanager/package-info.java | 21 + .../common/hooks/EnhancedPostToolEvent.java | 46 + .../ml/common/hooks/HookCallback.java | 23 + .../opensearch/ml/common/hooks/HookEvent.java | 33 + .../ml/common/hooks/HookProvider.java | 20 + .../ml/common/hooks/HookRegistry.java | 93 + .../ml/common/hooks/PostMemoryEvent.java | 50 + .../ml/common/hooks/PostToolEvent.java | 46 + .../ml/common/hooks/PreInvocationEvent.java | 23 + .../ml/common/hooks/PreLLMEvent.java | 37 + .../input/execute/agent/AgentMLInput.java | 17 +- .../agent/MLRegisterAgentRequest.java | 51 + ...CreateContextManagementTemplateAction.java | 17 + ...reateContextManagementTemplateRequest.java | 90 + ...eateContextManagementTemplateResponse.java | 71 + ...DeleteContextManagementTemplateAction.java | 17 + ...eleteContextManagementTemplateRequest.java | 74 + ...leteContextManagementTemplateResponse.java | 71 + .../MLGetContextManagementTemplateAction.java | 17 + ...MLGetContextManagementTemplateRequest.java | 74 + ...LGetContextManagementTemplateResponse.java | 62 + ...LListContextManagementTemplatesAction.java | 17 + ...ListContextManagementTemplatesRequest.java | 74 + ...istContextManagementTemplatesResponse.java | 77 + .../resources/index-mappings/ml_agent.json | 18 + .../ml_context_management_templates.json | 26 + .../ml/common/agent/MLAgentTest.java | 369 +++- 
.../CharacterBasedTokenCounterTest.java | 164 ++ .../agent/MLAgentGetResponseTest.java | 4 +- .../agent/MLRegisterAgentRequestTest.java | 261 +++ .../ml/engine/agents/AgentContextUtil.java | 182 ++ .../algorithms/agent/MLAgentExecutor.java | 283 ++- .../algorithms/agent/MLChatAgentRunner.java | 116 +- .../MLPlanExecuteAndReflectAgentRunner.java | 52 +- .../contextmanager/SlidingWindowManager.java | 191 ++ .../contextmanager/SummarizationManager.java | 435 +++++ .../ToolsOutputTruncateManager.java | 134 ++ .../algorithms/agent/MLAgentExecutorTest.java | 1663 ++--------------- .../agent/MLChatAgentRunnerTest.java | 8 +- ...LPlanExecuteAndReflectAgentRunnerTest.java | 4 +- .../SlidingWindowManagerTest.java | 234 +++ .../SummarizationManagerTest.java | 326 ++++ plugin/build.gradle | 10 + .../agent/MLAgentRegistrationValidator.java | 261 +++ .../agents/TransportRegisterAgentAction.java | 60 +- .../ContextManagementIndexUtils.java | 96 + .../ContextManagementTemplateService.java | 316 ++++ .../ContextManagerFactory.java | 120 ++ ...textManagementTemplateTransportAction.java | 67 + ...textManagementTemplateTransportAction.java | 67 + ...textManagementTemplateTransportAction.java | 67 + ...extManagementTemplatesTransportAction.java | 63 + .../ml/plugin/MachineLearningPlugin.java | 52 +- .../resources/MLResourceSharingExtension.java | 12 +- ...CreateContextManagementTemplateAction.java | 89 + ...DeleteContextManagementTemplateAction.java | 79 + ...tMLGetContextManagementTemplateAction.java | 78 + ...LListContextManagementTemplatesAction.java | 70 + .../ml/task/MLExecuteTaskRunner.java | 259 ++- .../opensearch/ml/utils/RestActionUtils.java | 1 + .../MLAgentRegistrationValidatorTests.java | 413 ++++ .../DeleteAgentTransportActionTests.java | 2 + .../agents/GetAgentTransportActionTests.java | 2 + .../RegisterAgentTransportActionTests.java | 10 +- .../ContextManagementIndexUtilsTests.java | 231 +++ ...ContextManagementTemplateServiceTests.java | 351 ++++ .../ContextManagerFactoryTests.java | 143 ++ ...anagementTemplateTransportActionTests.java | 196 ++ ...anagementTemplateTransportActionTests.java | 174 ++ ...anagementTemplateTransportActionTests.java | 192 ++ ...nagementTemplatesTransportActionTests.java | 235 +++ ...eContextManagementTemplateActionTests.java | 216 +++ ...eContextManagementTemplateActionTests.java | 181 ++ ...tContextManagementTemplateActionTests.java | 173 ++ ...ContextManagementTemplatesActionTests.java | 184 ++ .../ml/task/MLExecuteTaskRunnerTests.java | 10 +- 89 files changed, 9799 insertions(+), 1555 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java create mode 100644 
common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java create mode 100644 common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java create mode 100644 common/src/main/resources/index-mappings/ml_context_management_templates.json create mode 100644 common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java create mode 100644 
ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 190a6790f3..e17c8ac468 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -54,6 +54,7 @@ public class CommonValue { public static final String 
TASK_POLLING_JOB_INDEX = ".ml_commons_task_polling_job"; public static final String MCP_SESSION_MANAGEMENT_INDEX = ".plugins-ml-mcp-session-management"; public static final String MCP_TOOLS_INDEX = ".plugins-ml-mcp-tools"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX = ".plugins-ml-context-management-templates"; // index created in 3.1 to track all ml jobs created via job scheduler public static final String ML_JOBS_INDEX = ".plugins-ml-jobs"; public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); @@ -76,6 +77,7 @@ public class CommonValue { public static final String ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH = "index-mappings/ml_memory_long_term_history.json"; public static final String ML_MCP_SESSION_MANAGEMENT_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_session_management.json"; public static final String ML_MCP_TOOLS_INDEX_MAPPING_PATH = "index-mappings/ml_mcp_tools.json"; + public static final String ML_CONTEXT_MANAGEMENT_TEMPLATES_INDEX_MAPPING_PATH = "index-mappings/ml_context_management_templates.json"; public static final String ML_JOBS_INDEX_MAPPING_PATH = "index-mappings/ml_jobs.json"; public static final String ML_INDEX_INSIGHT_CONFIG_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_config.json"; public static final String ML_INDEX_INSIGHT_STORAGE_INDEX_MAPPING_PATH = "index-mappings/ml_index_insight_storage.json"; diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index c6c6607de6..06a4d69ebe 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -30,6 +30,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.telemetry.metrics.tags.Tags; import lombok.Builder; @@ -52,6 +53,8 @@ public class MLAgent implements ToXContentObject, Writeable { public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; public static final String APP_TYPE_FIELD = "app_type"; public static final String IS_HIDDEN_FIELD = "is_hidden"; + public static final String CONTEXT_MANAGEMENT_NAME_FIELD = "context_management_name"; + public static final String CONTEXT_MANAGEMENT_FIELD = "context_management"; private static final String LLM_INTERFACE_FIELD = "_llm_interface"; private static final String TAG_VALUE_UNKNOWN = "unknown"; private static final String TAG_MEMORY_TYPE = "memory_type"; @@ -59,6 +62,7 @@ public class MLAgent implements ToXContentObject, Writeable { public static final int AGENT_NAME_MAX_LENGTH = 128; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT = CommonValue.VERSION_2_13_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT = CommonValue.VERSION_3_3_0; private String name; private String type; @@ -73,6 +77,8 @@ public class MLAgent implements ToXContentObject, Writeable { private Instant lastUpdateTime; private String appType; private Boolean isHidden; + private String contextManagementName; + private ContextManagementTemplate contextManagement; private final String tenantId; @Builder(toBuilder = true) @@ -89,6 +95,8 @@ public MLAgent( Instant lastUpdateTime, String appType, Boolean isHidden, + String contextManagementName, + ContextManagementTemplate contextManagement, String tenantId ) { this.name = 
name; @@ -104,6 +112,8 @@ public MLAgent( this.appType = appType; // is_hidden field isn't going to be set by user. It will be set by the code. this.isHidden = isHidden; + this.contextManagementName = contextManagementName; + this.contextManagement = contextManagement; this.tenantId = tenantId; validate(); } @@ -123,7 +133,23 @@ public MLAgent( Boolean isHidden, String tenantId ) { - this(name, type, description, llm, null, tools, parameters, memory, createdTime, lastUpdateTime, appType, isHidden, tenantId); + this( + name, + type, + description, + llm, + null, + tools, + parameters, + memory, + createdTime, + lastUpdateTime, + appType, + isHidden, + null, + null, + tenantId + ); } private void validate() { @@ -150,6 +176,17 @@ private void validate() { } } } + validateContextManagement(); + } + + private void validateContextManagement() { + if (contextManagementName != null && contextManagement != null) { + throw new IllegalArgumentException("Cannot specify both context_management_name and context_management"); + } + + if (contextManagement != null && !contextManagement.isValid()) { + throw new IllegalArgumentException("Invalid context management configuration"); + } } private void validateMLAgentType(String agentType) { @@ -196,6 +233,12 @@ public MLAgent(StreamInput input) throws IOException { if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { isHidden = input.readOptionalBoolean(); } + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + contextManagementName = input.readOptionalString(); + if (input.readBoolean()) { + contextManagement = new ContextManagementTemplate(input); + } + } this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null; validate(); } @@ -245,6 +288,15 @@ public void writeTo(StreamOutput out) throws IOException { if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_HIDDEN_AGENT)) { out.writeOptionalBoolean(isHidden); } + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_CONTEXT_MANAGEMENT)) { + out.writeOptionalString(contextManagementName); + if (contextManagement != null) { + out.writeBoolean(true); + contextManagement.writeTo(out); + } else { + out.writeBoolean(false); + } + } if (streamOutputVersion.onOrAfter(VERSION_2_19_0)) { out.writeOptionalString(tenantId); } @@ -290,6 +342,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isHidden != null) { builder.field(MLModel.IS_HIDDEN_FIELD, isHidden); } + if (contextManagementName != null) { + builder.field(CONTEXT_MANAGEMENT_NAME_FIELD, contextManagementName); + } + if (contextManagement != null) { + builder.field(CONTEXT_MANAGEMENT_FIELD, contextManagement); + } if (tenantId != null) { builder.field(TENANT_ID_FIELD, tenantId); } @@ -318,6 +376,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid Instant lastUpdateTime = null; String appType = null; boolean isHidden = false; + String contextManagementName = null; + ContextManagementTemplate contextManagement = null; String tenantId = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); @@ -367,6 +427,12 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid if (parseHidden) isHidden = parser.booleanValue(); break; + case CONTEXT_MANAGEMENT_NAME_FIELD: + contextManagementName = parser.text(); + break; + case CONTEXT_MANAGEMENT_FIELD: + contextManagement = ContextManagementTemplate.parse(parser); + 
break; case TENANT_ID_FIELD: tenantId = parser.textOrNull(); break; @@ -390,6 +456,8 @@ private static MLAgent parseCommonFields(XContentParser parser, boolean parseHid .lastUpdateTime(lastUpdateTime) .appType(appType) .isHidden(isHidden) + .contextManagementName(contextManagementName) + .contextManagement(contextManagement) .tenantId(tenantId) .build(); } @@ -423,4 +491,39 @@ public Tags getTags() { return tags; } + + /** + * Check if this agent has context management configuration + * @return true if agent has either context management name or inline configuration + */ + public boolean hasContextManagement() { + return contextManagementName != null || contextManagement != null; + } + + /** + * Get the effective context management configuration for this agent. + * This method prioritizes inline configuration over template reference. + * Note: Template resolution requires external service call and should be handled by the caller. + * + * @return the inline context management configuration, or null if using template reference or no configuration + */ + public ContextManagementTemplate getInlineContextManagement() { + return contextManagement; + } + + /** + * Check if this agent uses a context management template reference + * @return true if agent references a context management template by name + */ + public boolean hasContextManagementTemplate() { + return contextManagementName != null; + } + + /** + * Get the context management template name if this agent references one + * @return the template name, or null if no template reference + */ + public String getContextManagementTemplateName() { + return contextManagementName; + } } diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java new file mode 100644 index 0000000000..c1529e6eda --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRule.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for activation rules that determine when a context manager should execute. + * Activation rules evaluate runtime conditions based on the current context state. + */ +public interface ActivationRule { + + /** + * Evaluate whether the activation condition is met. + * @param context the current context state + * @return true if the condition is met and the manager should activate, false otherwise + */ + boolean evaluate(ContextManagerContext context); + + /** + * Get a description of this activation rule for logging and debugging. + * @return a human-readable description of the rule + */ + String getDescription(); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java new file mode 100644 index 0000000000..f17eb8bc9e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ActivationRuleFactory.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory class for creating activation rules from configuration. 
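+ * <p>
+ * Illustrative usage (editor's sketch, not part of this change; the threshold values and the
+ * {@code context} variable below are hypothetical):
+ * <pre>{@code
+ * Map<String, Object> activation = Map.of("tokens_exceed", 4000, "message_count_exceed", 20);
+ * List<ActivationRule> rules = ActivationRuleFactory.createRules(activation);
+ * ActivationRule combined = ActivationRuleFactory.createCompositeRule(rules);
+ * ContextManagerContext context = ContextManagerContext.builder().userPrompt("...").build();
+ * if (combined != null && combined.evaluate(context)) {
+ *     // all configured thresholds were exceeded; run the context managers
+ * }
+ * }</pre>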
+ * Supports creating rules from configuration maps and combining multiple rules. + */ +@Log4j2 +public class ActivationRuleFactory { + + public static final String TOKENS_EXCEED_KEY = "tokens_exceed"; + public static final String MESSAGE_COUNT_EXCEED_KEY = "message_count_exceed"; + + /** + * Create activation rules from a configuration map. + * @param activationConfig the configuration map containing rule definitions + * @return a list of activation rules, or empty list if no valid rules found + */ + public static List createRules(Map activationConfig) { + List rules = new ArrayList<>(); + + if (activationConfig == null || activationConfig.isEmpty()) { + return rules; + } + + // Create tokens_exceed rule + if (activationConfig.containsKey(TOKENS_EXCEED_KEY)) { + try { + Object tokenValue = activationConfig.get(TOKENS_EXCEED_KEY); + int tokenThreshold = parseIntegerValue(tokenValue, TOKENS_EXCEED_KEY); + if (tokenThreshold > 0) { + rules.add(new TokensExceedRule(tokenThreshold)); + log.debug("Created TokensExceedRule with threshold: {}", tokenThreshold); + } else { + log.warn("Invalid token threshold value: {}. Must be positive integer.", tokenValue); + } + } catch (Exception e) { + log.error("Failed to create TokensExceedRule: {}", e.getMessage()); + } + } + + // Create message_count_exceed rule + if (activationConfig.containsKey(MESSAGE_COUNT_EXCEED_KEY)) { + try { + Object messageValue = activationConfig.get(MESSAGE_COUNT_EXCEED_KEY); + int messageThreshold = parseIntegerValue(messageValue, MESSAGE_COUNT_EXCEED_KEY); + if (messageThreshold > 0) { + rules.add(new MessageCountExceedRule(messageThreshold)); + log.debug("Created MessageCountExceedRule with threshold: {}", messageThreshold); + } else { + log.warn("Invalid message count threshold value: {}. Must be positive integer.", messageValue); + } + } catch (Exception e) { + log.error("Failed to create MessageCountExceedRule: {}", e.getMessage()); + } + } + + return rules; + } + + /** + * Create a composite rule that requires ALL rules to be satisfied (AND logic). + * @param rules the list of rules to combine + * @return a composite rule, or null if the list is empty + */ + public static ActivationRule createCompositeRule(List rules) { + if (rules == null || rules.isEmpty()) { + return null; + } + + if (rules.size() == 1) { + return rules.get(0); + } + + return new CompositeActivationRule(rules); + } + + /** + * Parse an integer value from configuration, handling various input types. + * @param value the value to parse + * @param fieldName the field name for error reporting + * @return the parsed integer value + * @throws IllegalArgumentException if the value cannot be parsed + */ + private static int parseIntegerValue(Object value, String fieldName) { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + try { + return Integer.parseInt((String) value); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid integer value for " + fieldName + ": " + value); + } + } else { + throw new IllegalArgumentException("Unsupported value type for " + fieldName + ": " + value.getClass().getSimpleName()); + } + } + + /** + * Composite activation rule that implements AND logic for multiple rules. 
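+ * <p>
+ * Editor's note (illustrative, not part of this change): every wrapped rule must pass, so a
+ * context that exceeds the token threshold but not the message-count threshold does not activate.
+ * <pre>{@code
+ * // equivalent to the composite of [tokens_exceed: 4000, message_count_exceed: 20]
+ * boolean activate = new TokensExceedRule(4000).evaluate(ctx)
+ *     && new MessageCountExceedRule(20).evaluate(ctx);
+ * }</pre>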
+ */ + private static class CompositeActivationRule implements ActivationRule { + private final List rules; + + public CompositeActivationRule(List rules) { + this.rules = new ArrayList<>(rules); + } + + @Override + public boolean evaluate(ContextManagerContext context) { + // All rules must evaluate to true (AND logic) + for (ActivationRule rule : rules) { + if (!rule.evaluate(context)) { + return false; + } + } + return true; + } + + @Override + public String getDescription() { + StringBuilder sb = new StringBuilder(); + sb.append("composite_rule: ["); + for (int i = 0; i < rules.size(); i++) { + if (i > 0) { + sb.append(" AND "); + } + sb.append(rules.get(i).getDescription()); + } + sb.append("]"); + return sb.toString(); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java new file mode 100644 index 0000000000..e9b87a20bc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounter.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.extern.log4j.Log4j2; + +/** + * Character-based token counter implementation. + * Uses a simple heuristic of approximately 4 characters per token. + * This is a fallback implementation when more sophisticated token counting is not available. + */ +@Log4j2 +public class CharacterBasedTokenCounter implements TokenCounter { + + private static final double CHARS_PER_TOKEN = 4.0; + + @Override + public int count(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + return (int) Math.ceil(text.length() / CHARS_PER_TOKEN); + } + + @Override + public String truncateFromEnd(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(0, maxChars); + } + + @Override + public String truncateFromBeginning(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + return text.substring(text.length() - maxChars); + } + + @Override + public String truncateMiddle(String text, int maxTokens) { + if (text == null || text.isEmpty()) { + return text; + } + + int currentTokens = count(text); + if (currentTokens <= maxTokens) { + return text; + } + + int maxChars = (int) (maxTokens * CHARS_PER_TOKEN); + if (maxChars >= text.length()) { + return text; + } + + // Keep equal portions from beginning and end + int halfChars = maxChars / 2; + String beginning = text.substring(0, halfChars); + String end = text.substring(text.length() - halfChars); + + return beginning + end; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java new file mode 100644 index 0000000000..40969b8c9a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java @@ -0,0 +1,260 @@ +/* 
+ * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context Management Template defines which context managers to use and when. + * This class represents a registered configuration that can be applied to + * agent execution to enable dynamic context optimization. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder(toBuilder = true) +public class ContextManagementTemplate implements ToXContentObject, Writeable { + + public static final String NAME_FIELD = "name"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String HOOKS_FIELD = "hooks"; + public static final String CREATED_TIME_FIELD = "created_time"; + public static final String LAST_MODIFIED_FIELD = "last_modified"; + public static final String CREATED_BY_FIELD = "created_by"; + + /** + * Unique name for the context management template + */ + private String name; + + /** + * Human-readable description of what this template does + */ + private String description; + + /** + * Map of hook names to lists of context manager configurations + */ + private Map> hooks; + + /** + * When this template was created + */ + private Instant createdTime; + + /** + * When this template was last modified + */ + private Instant lastModified; + + /** + * Who created this template + */ + private String createdBy; + + /** + * Constructor from StreamInput + */ + public ContextManagementTemplate(StreamInput input) throws IOException { + this.name = input.readString(); + this.description = input.readOptionalString(); + + // Read hooks map + int hooksSize = input.readInt(); + if (hooksSize > 0) { + this.hooks = input.readMap(StreamInput::readString, in -> { + try { + int listSize = in.readInt(); + List configs = new java.util.ArrayList<>(); + for (int i = 0; i < listSize; i++) { + configs.add(new ContextManagerConfig(in)); + } + return configs; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + this.createdTime = input.readOptionalInstant(); + this.lastModified = input.readOptionalInstant(); + this.createdBy = input.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeOptionalString(description); + + // Write hooks map + if (hooks != null) { + out.writeInt(hooks.size()); + out.writeMap(hooks, StreamOutput::writeString, (output, configs) -> { + try { + output.writeInt(configs.size()); + for (ContextManagerConfig config : configs) { + config.writeTo(output); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } else { + out.writeInt(0); + } + + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastModified); + out.writeOptionalString(createdBy); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (name != null) { + 
builder.field(NAME_FIELD, name); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD, description); + } + if (hooks != null && !hooks.isEmpty()) { + builder.field(HOOKS_FIELD, hooks); + } + if (createdTime != null) { + builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); + } + if (lastModified != null) { + builder.field(LAST_MODIFIED_FIELD, lastModified.toEpochMilli()); + } + if (createdBy != null) { + builder.field(CREATED_BY_FIELD, createdBy); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagementTemplate from XContentParser + */ + public static ContextManagementTemplate parse(XContentParser parser) throws IOException { + String name = null; + String description = null; + Map> hooks = null; + Instant createdTime = null; + Instant lastModified = null; + String createdBy = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case NAME_FIELD: + name = parser.text(); + break; + case DESCRIPTION_FIELD: + description = parser.text(); + break; + case HOOKS_FIELD: + hooks = parseHooks(parser); + break; + case CREATED_TIME_FIELD: + createdTime = Instant.ofEpochMilli(parser.longValue()); + break; + case LAST_MODIFIED_FIELD: + lastModified = Instant.ofEpochMilli(parser.longValue()); + break; + case CREATED_BY_FIELD: + createdBy = parser.text(); + break; + default: + parser.skipChildren(); + break; + } + } + + return ContextManagementTemplate + .builder() + .name(name) + .description(description) + .hooks(hooks) + .createdTime(createdTime) + .lastModified(lastModified) + .createdBy(createdBy) + .build(); + } + + /** + * Parse hooks configuration from XContentParser + */ + private static Map> parseHooks(XContentParser parser) throws IOException { + Map> hooks = new java.util.HashMap<>(); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String hookName = parser.currentName(); + parser.nextToken(); // Move to START_ARRAY + + List configs = new java.util.ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + configs.add(ContextManagerConfig.parse(parser)); + } + + hooks.put(hookName, configs); + } + + return hooks; + } + + /** + * Validate the template configuration + */ + public boolean isValid() { + if (name == null || name.trim().isEmpty()) { + return false; + } + + // Allow null hooks (no context management) but not empty hooks map (misconfiguration) + if (hooks != null) { + if (hooks.isEmpty()) { + return false; + } + + // Validate all context manager configs + for (List configs : hooks.values()) { + if (configs != null) { + for (ContextManagerConfig config : configs) { + if (!config.isValid()) { + return false; + } + } + } + } + } + + return true; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java new file mode 100644 index 0000000000..325f98900a --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManager.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.Map; + +/** + * Base interface for all context managers. 
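+ * <p>
+ * Illustrative implementation (editor's sketch, not part of this change; the manager name,
+ * the {@code max_tokens} key, and the assumption that the config map is
+ * {@code Map<String, Object>} are hypothetical):
+ * <pre>{@code
+ * public class TrimToolOutputManager implements ContextManager {
+ *     private int maxTokens = 4000;
+ *
+ *     public String getType() { return "TrimToolOutputManager"; }
+ *
+ *     public void initialize(Map<String, Object> config) {
+ *         Object v = config.get("max_tokens");
+ *         if (v instanceof Number) { maxTokens = ((Number) v).intValue(); }
+ *     }
+ *
+ *     public boolean shouldActivate(ContextManagerContext context) {
+ *         return context.getEstimatedTokenCount() > maxTokens;
+ *     }
+ *
+ *     public void execute(ContextManagerContext context) {
+ *         // inspect and transform the context here, e.g. drop or shorten tool interactions
+ *     }
+ * }
+ * }</pre>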
+ * Context managers are pluggable components that inspect and transform + * agent context components before they are sent to an LLM. + */ +public interface ContextManager { + + /** + * Get the type identifier for this context manager + * @return String identifying the manager type + */ + String getType(); + + /** + * Initialize the context manager with configuration + * @param config Configuration map for the manager + */ + void initialize(Map config); + + /** + * Check if this context manager should activate based on current context + * @param context The current context manager context + * @return true if the manager should execute, false otherwise + */ + boolean shouldActivate(ContextManagerContext context); + + /** + * Execute the context transformation + * @param context The context manager context to transform + */ + void execute(ContextManagerContext context); + +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java new file mode 100644 index 0000000000..92755cb243 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerConfig.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Configuration for a context manager within a context management template. + * This class holds the configuration details for how a specific context manager + * should be configured and when it should activate. 
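+ * <p>
+ * Illustrative configuration (editor's sketch, not part of this change; the template name and
+ * the {@code max_output_tokens} config key are hypothetical, while the manager type and the
+ * {@code POST_TOOL} hook name appear elsewhere in this patch):
+ * <pre>{@code
+ * ContextManagerConfig truncate = new ContextManagerConfig(
+ *     "ToolsOutputTruncateManager",
+ *     Map.of("tokens_exceed", 4000),        // activation conditions
+ *     Map.of("max_output_tokens", 1000));   // manager-specific settings
+ * Map<String, List<ContextManagerConfig>> hooks = Map.of("POST_TOOL", List.of(truncate));
+ * ContextManagementTemplate template = ContextManagementTemplate.builder()
+ *     .name("default-truncation")
+ *     .hooks(hooks)
+ *     .build();
+ * assert template.isValid();
+ * }</pre>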
+ */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerConfig implements ToXContentObject, Writeable { + + public static final String TYPE_FIELD = "type"; + public static final String ACTIVATION_FIELD = "activation"; + public static final String CONFIG_FIELD = "config"; + + /** + * The type of context manager (e.g., "ToolsOutputTruncateManager") + */ + private String type; + + /** + * Activation conditions that determine when this manager should execute + */ + private Map activation; + + /** + * Configuration parameters specific to this manager type + */ + private Map config; + + /** + * Constructor from StreamInput + */ + public ContextManagerConfig(StreamInput input) throws IOException { + this.type = input.readString(); + this.activation = input.readMap(); + this.config = input.readMap(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + out.writeMap(activation); + out.writeMap(config); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (type != null) { + builder.field(TYPE_FIELD, type); + } + if (activation != null && !activation.isEmpty()) { + builder.field(ACTIVATION_FIELD, activation); + } + if (config != null && !config.isEmpty()) { + builder.field(CONFIG_FIELD, config); + } + + builder.endObject(); + return builder; + } + + /** + * Parse ContextManagerConfig from XContentParser + */ + public static ContextManagerConfig parse(XContentParser parser) throws IOException { + String type = null; + Map activation = null; + Map config = null; + + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + parser.nextToken(); + } + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case TYPE_FIELD: + type = parser.text(); + break; + case ACTIVATION_FIELD: + activation = parser.map(); + break; + case CONFIG_FIELD: + config = parser.map(); + break; + default: + parser.skipChildren(); + break; + } + } + + return new ContextManagerConfig(type, activation, config); + } + + /** + * Validate the configuration + * @return true if configuration is valid, false otherwise + */ + public boolean isValid() { + return type != null && !type.trim().isEmpty(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java new file mode 100644 index 0000000000..9854b78dba --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java @@ -0,0 +1,180 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Context object that contains all components of the agent execution context. + * This object is passed to context managers for inspection and transformation. 
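+ * <p>
+ * Illustrative usage (editor's sketch, not part of this change; the prompt and tool strings are
+ * hypothetical):
+ * <pre>{@code
+ * ContextManagerContext ctx = ContextManagerContext.builder()
+ *     .systemPrompt("You are a helpful assistant.")
+ *     .userPrompt("Summarize the last deployment failure.")
+ *     .build();
+ * ctx.addToolInteraction("SearchIndexTool returned 3 hits");
+ * int estimatedTokens = ctx.getEstimatedTokenCount();   // rough estimate, ~4 characters per token
+ * int messages = ctx.getMessageCount();                 // number of chat-history interactions
+ * }</pre>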
+ */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ContextManagerContext { + + /** + * The invocation state from the hook system + */ + private Map invocationState; + + /** + * The system prompt for the LLM + */ + private String systemPrompt; + + /** + * The chat history as a list of interactions + */ + @Builder.Default + private List chatHistory = new ArrayList<>(); + + /** + * The current user prompt/input + */ + private String userPrompt; + + /** + * The tool configurations available to the agent + */ + @Builder.Default + private List toolConfigs = new ArrayList<>(); + + /** + * The tool interactions/results from tool executions + */ + @Builder.Default + private List toolInteractions = new ArrayList<>(); + + /** + * Additional parameters for context processing + */ + @Builder.Default + private Map parameters = new HashMap<>(); + + /** + * Get the total token count for the current context. + * This is a utility method that can be used by context managers. + * @return estimated token count + */ + public int getEstimatedTokenCount() { + int tokenCount = 0; + + // Estimate tokens for system prompt + if (systemPrompt != null) { + tokenCount += estimateTokens(systemPrompt); + } + + // Estimate tokens for user prompt + if (userPrompt != null) { + tokenCount += estimateTokens(userPrompt); + } + + // Estimate tokens for chat history + for (Interaction interaction : chatHistory) { + if (interaction.getInput() != null) { + tokenCount += estimateTokens(interaction.getInput()); + } + if (interaction.getResponse() != null) { + tokenCount += estimateTokens(interaction.getResponse()); + } + } + + // Estimate tokens for tool interactions + for (String interaction : toolInteractions) { + tokenCount += estimateTokens(interaction); + } + + return tokenCount; + } + + /** + * Get the message count in chat history. + * @return number of messages in chat history + */ + public int getMessageCount() { + return chatHistory.size(); + } + + /** + * Simple token estimation based on character count. + * This is a fallback method - more sophisticated token counting should be implemented + * in dedicated TokenCounter implementations. + * @param text the text to estimate tokens for + * @return estimated token count + */ + private int estimateTokens(String text) { + if (text == null || text.isEmpty()) { + return 0; + } + // Rough estimation: 1 token per 4 characters + return (int) Math.ceil(text.length() / 4.0); + } + + /** + * Add a tool interaction to the context. + * @param interaction the tool interaction to add + */ + public void addToolInteraction(String interaction) { + if (toolInteractions == null) { + toolInteractions = new ArrayList<>(); + } + toolInteractions.add(interaction); + } + + /** + * Add an interaction to the chat history. + * @param interaction the interaction to add + */ + public void addChatHistoryInteraction(Interaction interaction) { + if (chatHistory == null) { + chatHistory = new ArrayList<>(); + } + chatHistory.add(interaction); + } + + /** + * Clear the chat history. + */ + public void clearChatHistory() { + if (chatHistory != null) { + chatHistory.clear(); + } + } + + /** + * Get a parameter value by key. + * @param key the parameter key + * @return the parameter value, or null if not found + */ + public Object getParameter(String key) { + return parameters != null ? parameters.get(key) : null; + } + + /** + * Set a parameter value. 
+ * @param key the parameter key + * @param value the parameter value + */ + public void setParameter(String key, String value) { + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(key, value); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java new file mode 100644 index 0000000000..35109c53dd --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java @@ -0,0 +1,193 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PostMemoryEvent; +import org.opensearch.ml.common.hooks.PreLLMEvent; + +import lombok.extern.log4j.Log4j2; + +/** + * Hook provider that integrates context managers with the hook registry. + * This class manages the execution of context managers based on hook events. + */ +@Log4j2 +public class ContextManagerHookProvider implements HookProvider { + private final List contextManagers; + private final Map> hookToManagersMap; + + /** + * Constructor for ContextManagerHookProvider + * @param contextManagers List of context managers to register + */ + public ContextManagerHookProvider(List contextManagers) { + this.contextManagers = new ArrayList<>(contextManagers); + this.hookToManagersMap = new HashMap<>(); + + // Group managers by hook type based on their configuration + // This would typically be done based on the template configuration + // For now, we'll organize them by common hook types + organizeManagersByHook(); + } + + /** + * Register hook callbacks with the provided registry + * @param registry The HookRegistry to register callbacks with + */ + @Override + public void registerHooks(HookRegistry registry) { + // Register callbacks for each hook type + registry.addCallback(PreLLMEvent.class, this::handlePreLLM); + registry.addCallback(EnhancedPostToolEvent.class, this::handlePostTool); + registry.addCallback(PostMemoryEvent.class, this::handlePostMemory); + + log.info("Registered context manager hooks for {} managers", contextManagers.size()); + } + + /** + * Handle PreLLM hook events + * @param event The PreLLM event + */ + private void handlePreLLM(PreLLMEvent event) { + log.debug("Handling PreLLM event"); + executeManagersForHook("PRE_LLM", event.getContext()); + } + + /** + * Handle PostTool hook events + * @param event The EnhancedPostTool event + */ + private void handlePostTool(EnhancedPostToolEvent event) { + log.debug("Handling PostTool event"); + executeManagersForHook("POST_TOOL", event.getContext()); + } + + /** + * Handle PostMemory hook events + * @param event The PostMemory event + */ + private void handlePostMemory(PostMemoryEvent event) { + log.debug("Handling PostMemory event"); + executeManagersForHook("POST_MEMORY", event.getContext()); + } + + /** + * Execute context managers for a specific hook + * @param hookName The name of the hook + * @param context The context manager context + */ + private void executeManagersForHook(String hookName, ContextManagerContext context) { + List managers = 
hookToManagersMap.get(hookName); + if (managers != null && !managers.isEmpty()) { + log.debug("Executing {} context managers for hook: {}", managers.size(), hookName); + + for (ContextManager manager : managers) { + try { + if (manager.shouldActivate(context)) { + log.debug("Executing context manager: {}", manager.getType()); + manager.execute(context); + log.debug("Successfully executed context manager: {}", manager.getType()); + } else { + log.debug("Context manager {} activation conditions not met, skipping", manager.getType()); + } + } catch (Exception e) { + log.error("Context manager {} failed: {}", manager.getType(), e.getMessage(), e); + // Continue with other managers even if one fails + } + } + } else { + log.debug("No context managers registered for hook: {}", hookName); + } + } + + /** + * Organize managers by hook type + * This is a simplified implementation - in practice, this would be based on + * the context management template configuration + */ + private void organizeManagersByHook() { + // For now, we'll assign managers to hooks based on their type + // This would be replaced with actual template-based configuration + for (ContextManager manager : contextManagers) { + String managerType = manager.getType(); + + // Assign managers to appropriate hooks based on their type + if ("ToolsOutputTruncateManager".equals(managerType)) { + addManagerToHook("POST_TOOL", manager); + } else if ("SlidingWindowManager".equals(managerType) || "SummarizingManager".equals(managerType)) { + addManagerToHook("POST_MEMORY", manager); + addManagerToHook("PRE_LLM", manager); + } else if ("SystemPromptAugmentationManager".equals(managerType)) { + addManagerToHook("PRE_LLM", manager); + } else { + // Default to PRE_LLM for unknown types + addManagerToHook("PRE_LLM", manager); + } + } + } + + /** + * Add a manager to a specific hook + * @param hookName The hook name + * @param manager The context manager + */ + private void addManagerToHook(String hookName, ContextManager manager) { + hookToManagersMap.computeIfAbsent(hookName, k -> new ArrayList<>()).add(manager); + log.debug("Added manager {} to hook {}", manager.getType(), hookName); + } + + /** + * Update the hook-to-managers mapping based on template configuration + * @param hookConfiguration Map of hook names to manager configurations + */ + public void updateHookConfiguration(Map> hookConfiguration) { + hookToManagersMap.clear(); + + for (Map.Entry> entry : hookConfiguration.entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + // Find the corresponding context manager + ContextManager manager = findManagerByType(config.getType()); + if (manager != null) { + addManagerToHook(hookName, manager); + } else { + log.warn("Context manager of type {} not found", config.getType()); + } + } + } + + log.info("Updated hook configuration with {} hooks", hookConfiguration.size()); + } + + /** + * Find a context manager by its type + * @param type The manager type + * @return The context manager or null if not found + */ + private ContextManager findManagerByType(String type) { + return contextManagers.stream().filter(manager -> type.equals(manager.getType())).findFirst().orElse(null); + } + + /** + * Get the number of managers registered for a specific hook + * @param hookName The hook name + * @return Number of managers + */ + public int getManagerCount(String hookName) { + List managers = hookToManagersMap.get(hookName); + return managers != null ? 
managers.size() : 0; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java new file mode 100644 index 0000000000..f3a4d5c57e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/MessageCountExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the chat history message count exceeds a specified threshold. + */ +@AllArgsConstructor +@Getter +public class MessageCountExceedRule implements ActivationRule { + + private final int messageThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentMessageCount = context.getMessageCount(); + return currentMessageCount > messageThreshold; + } + + @Override + public String getDescription() { + return "message_count_exceed: " + messageThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java new file mode 100644 index 0000000000..42bbd813ee --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokenCounter.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +/** + * Interface for counting and truncating tokens in text. + * Provides methods for accurate token counting and various truncation strategies. + */ +public interface TokenCounter { + + /** + * Count the number of tokens in the given text. + * @param text the text to count tokens for + * @return the number of tokens + */ + int count(String text); + + /** + * Truncate text from the end to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromEnd(String text, int maxTokens); + + /** + * Truncate text from the beginning to fit within the specified token limit. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateFromBeginning(String text, int maxTokens); + + /** + * Truncate text from the middle to fit within the specified token limit. + * Preserves both beginning and end portions of the text. + * @param text the text to truncate + * @param maxTokens the maximum number of tokens to keep + * @return the truncated text + */ + String truncateMiddle(String text, int maxTokens); +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java new file mode 100644 index 0000000000..e4bc0544f0 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/TokensExceedRule.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Activation rule that triggers when the context token count exceeds a specified threshold. 
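+ * <p>
+ * Editor's note (illustrative, not part of this change): with the character-based estimate of
+ * roughly 4 characters per token, a context holding about 20,000 characters is estimated at
+ * ~5,000 tokens and therefore trips a 4,000-token threshold.
+ * <pre>{@code
+ * TokensExceedRule rule = new TokensExceedRule(4000);
+ * boolean activate = rule.evaluate(ctx);   // true once ctx.getEstimatedTokenCount() > 4000
+ * }</pre>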
+ */ +@AllArgsConstructor +@Getter +public class TokensExceedRule implements ActivationRule { + + private final int tokenThreshold; + + @Override + public boolean evaluate(ContextManagerContext context) { + if (context == null) { + return false; + } + + int currentTokenCount = context.getEstimatedTokenCount(); + return currentTokenCount > tokenThreshold; + } + + @Override + public String getDescription() { + return "tokens_exceed: " + tokenThreshold; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java new file mode 100644 index 0000000000..b8d8cb5cc4 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Context management framework for OpenSearch ML-Commons. + * + * This package provides a pluggable context management system that allows for dynamic + * optimization of LLM context windows through configurable context managers. + * + * Key components: + * - {@link org.opensearch.ml.common.contextmanager.ContextManager}: Base interface for all context managers + * - {@link org.opensearch.ml.common.contextmanager.ContextManagerContext}: Context object containing all agent execution state + * - {@link org.opensearch.ml.common.contextmanager.ActivationRule}: Interface for rules that determine when managers should execute + * - {@link org.opensearch.ml.common.contextmanager.ActivationRuleFactory}: Factory for creating activation rules from configuration + * + * The system integrates with the existing hook framework to provide seamless context optimization + * during agent execution without breaking existing functionality. + */ +package org.opensearch.ml.common.contextmanager; diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java new file mode 100644 index 0000000000..7db6341c9e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/EnhancedPostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Enhanced version of PostToolEvent that includes context manager context. + * This event is triggered after tool execution and provides access to both + * tool results and the full context, allowing context managers to modify + * tool outputs and other context components. 
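+ * <p>
+ * Illustrative emission (editor's sketch, not part of this change; it assumes tool results are
+ * keyed maps and that {@code hookRegistry} and {@code ctx} are in scope):
+ * <pre>{@code
+ * List<Map<String, Object>> toolResults = List.of(Map.of("tool", "SearchIndexTool", "output", "..."));
+ * hookRegistry.emit(new EnhancedPostToolEvent(toolResults, null, ctx, Map.of()));
+ * }</pre>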
+ */
+public class EnhancedPostToolEvent extends PostToolEvent {
+    private final ContextManagerContext context;
+
+    /**
+     * Constructor for EnhancedPostToolEvent
+     * @param toolResults List of tool execution results
+     * @param error Exception that occurred during tool execution, null if successful
+     * @param context The context manager context containing all context components
+     * @param invocationState The current state of the agent invocation
+     */
+    public EnhancedPostToolEvent(
+        List<Map<String, Object>> toolResults,
+        Exception error,
+        ContextManagerContext context,
+        Map<String, Object> invocationState
+    ) {
+        super(toolResults, error, invocationState);
+        this.context = context;
+    }
+
+    /**
+     * Get the context manager context
+     * @return ContextManagerContext containing all context components
+     */
+    public ContextManagerContext getContext() {
+        return context;
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java
new file mode 100644
index 0000000000..13e7299e01
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookCallback.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.hooks;
+
+/**
+ * Functional interface for handling specific hook events.
+ * Implementations of this interface define the behavior to execute
+ * when a particular hook event is triggered.
+ *
+ * @param <T> The type of HookEvent this callback handles
+ */
+@FunctionalInterface
+public interface HookCallback<T extends HookEvent> {
+
+    /**
+     * Handle the hook event
+     * @param event The hook event to handle
+     */
+    void handle(T event);
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java
new file mode 100644
index 0000000000..c7f1503b61
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookEvent.java
@@ -0,0 +1,33 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.hooks;
+
+import java.util.Map;
+
+/**
+ * Base class for all hook events in the ML agent lifecycle.
+ * Hook events are strongly-typed events that carry context information
+ * for different stages of agent execution.
+ */
+public abstract class HookEvent {
+    private final Map<String, Object> invocationState;
+
+    /**
+     * Constructor for HookEvent
+     * @param invocationState The current state of the agent invocation
+     */
+    protected HookEvent(Map<String, Object> invocationState) {
+        this.invocationState = invocationState;
+    }
+
+    /**
+     * Get the invocation state
+     * @return Map containing the current invocation state
+     */
+    public Map<String, Object> getInvocationState() {
+        return invocationState;
+    }
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java
new file mode 100644
index 0000000000..d6612f6749
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookProvider.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.hooks;
+
+/**
+ * Interface for providers that register hook callbacks with the HookRegistry.
+ * Implementations of this interface define which hooks they want to listen to
+ * and provide the callback implementations.
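+ * <p>
+ * Illustrative provider (editor's sketch, not part of this change):
+ * <pre>{@code
+ * public class LoggingHookProvider implements HookProvider {
+ *     public void registerHooks(HookRegistry registry) {
+ *         registry.addCallback(PreLLMEvent.class, event ->
+ *             System.out.println("estimated tokens before LLM call: "
+ *                 + event.getContext().getEstimatedTokenCount()));
+ *     }
+ * }
+ * }</pre>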
+ */
+public interface HookProvider {
+
+    /**
+     * Register hook callbacks with the provided registry
+     * @param registry The HookRegistry to register callbacks with
+     */
+    void registerHooks(HookRegistry registry);
+}
diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java
new file mode 100644
index 0000000000..32076d0d78
--- /dev/null
+++ b/common/src/main/java/org/opensearch/ml/common/hooks/HookRegistry.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ml.common.hooks;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import lombok.extern.log4j.Log4j2;
+
+/**
+ * Registry for managing hook callbacks and event emission.
+ * This class manages the registration of callbacks for different hook event types
+ * and provides methods to emit events to registered callbacks.
+ */
+@Log4j2
+public class HookRegistry {
+    private final Map<Class<? extends HookEvent>, List<HookCallback<? extends HookEvent>>> callbacks;
+
+    /**
+     * Constructor for HookRegistry
+     */
+    public HookRegistry() {
+        this.callbacks = new ConcurrentHashMap<>();
+    }
+
+    /**
+     * Add a callback for a specific hook event type
+     * @param eventType The class of the hook event
+     * @param callback The callback to execute when the event is emitted
+     * @param <T> The type of hook event
+     */
+    public <T extends HookEvent> void addCallback(Class<T> eventType, HookCallback<T> callback) {
+        callbacks.computeIfAbsent(eventType, k -> new ArrayList<>()).add(callback);
+        log.debug("Registered callback for event type: {}", eventType.getSimpleName());
+    }
+
+    /**
+     * Emit an event to all registered callbacks for that event type
+     * @param event The hook event to emit
+     * @param <T> The type of hook event
+     */
+    @SuppressWarnings("unchecked")
+    public <T extends HookEvent> void emit(T event) {
+        Class<? extends HookEvent> eventType = event.getClass();
+        List<HookCallback<? extends HookEvent>> eventCallbacks = callbacks.get(eventType);
+
+        log
+            .info(
+                "HookRegistry.emit() called for event type: {}, callbacks available: {}",
+                eventType.getSimpleName(),
+                eventCallbacks != null ? eventCallbacks.size() : 0
+            );
+
+        if (eventCallbacks != null) {
+            log.info("Emitting {} event to {} callbacks", eventType.getSimpleName(), eventCallbacks.size());
+
+            for (HookCallback<? extends HookEvent> callback : eventCallbacks) {
+                try {
+                    log.info("Executing callback: {}", callback.getClass().getSimpleName());
+                    ((HookCallback<T>) callback).handle(event);
+                } catch (Exception e) {
+                    log.error("Error executing hook callback for event type {}: {}", eventType.getSimpleName(), e.getMessage(), e);
+                    // Continue with other callbacks even if one fails
+                }
+            }
+        } else {
+            log.warn("No callbacks registered for event type: {}", eventType.getSimpleName());
+        }
+    }
+
+    /**
+     * Get the number of registered callbacks for a specific event type
+     * @param eventType The class of the hook event
+     * @return Number of registered callbacks
+     */
+    public int getCallbackCount(Class<? extends HookEvent> eventType) {
+        List<HookCallback<? extends HookEvent>> eventCallbacks = callbacks.get(eventType);
+        return eventCallbacks != null ?
eventCallbacks.size() : 0; + } + + /** + * Clear all registered callbacks + */ + public void clear() { + callbacks.clear(); + log.debug("Cleared all hook callbacks"); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java new file mode 100644 index 0000000000..006f6e8069 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostMemoryEvent.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.conversation.Interaction; + +/** + * Hook event triggered after memory retrieval in the agent lifecycle. + * This event provides access to the retrieved chat history and context, + * allowing context managers to modify the memory before it's used. + */ +public class PostMemoryEvent extends HookEvent { + private final ContextManagerContext context; + private final List retrievedHistory; + + /** + * Constructor for PostMemoryEvent + * @param context The context manager context containing all context components + * @param retrievedHistory The chat history retrieved from memory + * @param invocationState The current state of the agent invocation + */ + public PostMemoryEvent(ContextManagerContext context, List retrievedHistory, Map invocationState) { + super(invocationState); + this.context = context; + this.retrievedHistory = retrievedHistory; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } + + /** + * Get the retrieved chat history + * @return List of interactions retrieved from memory + */ + public List getRetrievedHistory() { + return retrievedHistory; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java new file mode 100644 index 0000000000..609d6028da --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PostToolEvent.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.List; +import java.util.Map; + +/** + * Hook event triggered after tool execution in the agent lifecycle. + * This event provides access to tool results and any errors that occurred. 
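+ * <p>
+ * Illustrative handling (editor's sketch, not part of this change; {@code registry} is assumed
+ * to be a HookRegistry in scope):
+ * <pre>{@code
+ * registry.addCallback(PostToolEvent.class, event -> {
+ *     if (event.getError() != null) {
+ *         System.err.println("tool execution failed: " + event.getError().getMessage());
+ *     }
+ * });
+ * }</pre>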
+ */ +public class PostToolEvent extends HookEvent { + private final List> toolResults; + private final Exception error; + + /** + * Constructor for PostToolEvent + * @param toolResults List of tool execution results + * @param error Exception that occurred during tool execution, null if successful + * @param invocationState The current state of the agent invocation + */ + public PostToolEvent(List> toolResults, Exception error, Map invocationState) { + super(invocationState); + this.toolResults = toolResults; + this.error = error; + } + + /** + * Get the tool execution results + * @return List of tool results + */ + public List> getToolResults() { + return toolResults; + } + + /** + * Get the error that occurred during tool execution + * @return Exception if an error occurred, null otherwise + */ + public Exception getError() { + return error; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java new file mode 100644 index 0000000000..42e0cc1d0c --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreInvocationEvent.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.input.Input; + +public class PreInvocationEvent extends HookEvent { + private final Input input; + + public PreInvocationEvent(Input input, Map invocationState) { + super(invocationState); + this.input = input; + } + + public Input getInput() { + return input; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java new file mode 100644 index 0000000000..1b82b04512 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/hooks/PreLLMEvent.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.hooks; + +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Hook event triggered before LLM invocation in the agent lifecycle. + * This event provides access to the context that will be sent to the LLM, + * allowing context managers to modify it before the LLM call. 
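+ * <p>
+ * Illustrative emission point (editor's sketch, not part of this change; {@code hookRegistry}
+ * and {@code ctx} are assumed to be in scope):
+ * <pre>{@code
+ * hookRegistry.emit(new PreLLMEvent(ctx, Map.of()));   // lets context managers trim ctx first
+ * String prompt = ctx.getUserPrompt();                 // read the possibly-modified context
+ * }</pre>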
+ */ +public class PreLLMEvent extends HookEvent { + private final ContextManagerContext context; + + /** + * Constructor for PreLLMEvent + * @param context The context manager context containing all context components + * @param invocationState The current state of the agent invocation + */ + public PreLLMEvent(ContextManagerContext context, Map invocationState) { + super(invocationState); + this.context = context; + } + + /** + * Get the context manager context + * @return ContextManagerContext containing all context components + */ + public ContextManagerContext getContext() { + return context; + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java index 48a24bd324..71368ff23c 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/agent/AgentMLInput.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.agent.AgentInput; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; @@ -54,6 +55,14 @@ public class AgentMLInput extends MLInput { @Setter private AgentInput agentInput; + @Getter + @Setter + private HookRegistry hookRegistry; + + @Getter + @Setter + private String contextManagementName; + @Builder(builderMethodName = "AgentMLInputBuilder") public AgentMLInput(String agentId, String tenantId, FunctionName functionName, MLInputDataset inputDataset) { this(agentId, tenantId, functionName, inputDataset, false); @@ -105,6 +114,7 @@ public void writeTo(StreamOutput out) throws IOException { agentInput.writeTo(out); } } + // Note: contextManagementName and hookRegistry are not serialized as they are runtime-only fields } public AgentMLInput(StreamInput in) throws IOException { @@ -139,7 +149,12 @@ public AgentMLInput(XContentParser parser, FunctionName functionName) throws IOE break; case PARAMETERS_FIELD: // Legacy format - parse parameters into RemoteInferenceInputDataSet - Map parameters = StringUtils.getParameterMap(parser.map()); + Map parameterObjs = parser.map(); + Map parameters = StringUtils.getParameterMap(parameterObjs); + // Extract context_management from parameters + if (parameterObjs.containsKey("context_management")) { + contextManagementName = (String) parameterObjs.get("context_management"); + } inputDataset = new RemoteInferenceInputDataSet(parameters); break; case INPUT_FIELD: diff --git a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java index c73f2150aa..90096044f0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java @@ -48,11 +48,62 @@ public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; if (mlAgent == null) { exception = addValidationError("ML agent can't be null", exception); + } else { + // Basic validation - check for conflicting configuration (following connector pattern) + if (mlAgent.getContextManagementName() != null && mlAgent.getContextManagement() != null) { + exception 
= addValidationError("Cannot specify both context_management_name and context_management", exception); + } + + // Validate context management template name + if (mlAgent.getContextManagementName() != null) { + exception = validateContextManagementTemplateName(mlAgent.getContextManagementName(), exception); + } + + // Validate inline context management configuration + if (mlAgent.getContextManagement() != null) { + exception = validateInlineContextManagement(mlAgent.getContextManagement(), exception); + } + } + + return exception; + } + + private ActionRequestValidationException validateContextManagementTemplateName( + String templateName, + ActionRequestValidationException exception + ) { + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Context management template name cannot be null or empty", exception); + } else if (templateName.length() > 256) { + exception = addValidationError("Context management template name cannot exceed 256 characters", exception); + } else if (!templateName.matches("^[a-zA-Z0-9._-]+$")) { + exception = addValidationError( + "Context management template name can only contain letters, numbers, underscores, hyphens, and dots", + exception + ); } + return exception; + } + private ActionRequestValidationException validateInlineContextManagement( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement, + ActionRequestValidationException exception + ) { + if (contextManagement.getHooks() != null) { + for (String hookName : contextManagement.getHooks().keySet()) { + if (!isValidHookName(hookName)) { + exception = addValidationError("Invalid hook name: " + hookName, exception); + } + } + } return exception; } + private boolean isValidHookName(String hookName) { + // Define valid hook names based on the system's supported hooks + return hookName.equals("POST_TOOL") || hookName.equals("PRE_LLM") || hookName.equals("PRE_TOOL") || hookName.equals("POST_LLM"); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..b6116afa4f --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLCreateContextManagementTemplateAction extends ActionType { + public static MLCreateContextManagementTemplateAction INSTANCE = new MLCreateContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/create"; + + private MLCreateContextManagementTemplateAction() { + super(NAME, MLCreateContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java new file mode 100644 index 0000000000..ee98607505 --- /dev/null +++ 
b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateRequest.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLCreateContextManagementTemplateRequest extends ActionRequest { + + String templateName; + ContextManagementTemplate template; + + @Builder + public MLCreateContextManagementTemplateRequest(String templateName, ContextManagementTemplate template) { + this.templateName = templateName; + this.template = template; + } + + public MLCreateContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.template = new ContextManagementTemplate(in); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + if (template == null) { + exception = addValidationError("Context management template cannot be null", exception); + } else { + // Validate template structure + if (template.getName() == null || template.getName().trim().isEmpty()) { + exception = addValidationError("Template name in body cannot be null or empty", exception); + } + if (template.getHooks() == null || template.getHooks().isEmpty()) { + exception = addValidationError("Template must define at least one hook", exception); + } + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + template.writeTo(out); + } + + public static MLCreateContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateContextManagementTemplateRequest) { + return (MLCreateContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java 
b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java new file mode 100644 index 0000000000..85265bb333 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLCreateContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLCreateContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLCreateContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLCreateContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLCreateContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateContextManagementTemplateResponse) { + return (MLCreateContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..6074891afa --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class 
MLDeleteContextManagementTemplateAction extends ActionType { + public static MLDeleteContextManagementTemplateAction INSTANCE = new MLDeleteContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/delete"; + + private MLDeleteContextManagementTemplateAction() { + super(NAME, MLDeleteContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java new file mode 100644 index 0000000000..e7b6e69200 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLDeleteContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLDeleteContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLDeleteContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLDeleteContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLDeleteContextManagementTemplateRequest) { + return (MLDeleteContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLDeleteContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java 
b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java new file mode 100644 index 0000000000..415323f932 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLDeleteContextManagementTemplateResponse.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import lombok.Getter; + +@Getter +public class MLDeleteContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATE_NAME_FIELD = "template_name"; + public static final String STATUS_FIELD = "status"; + + private String templateName; + private String status; + + public MLDeleteContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + this.status = in.readString(); + } + + public MLDeleteContextManagementTemplateResponse(String templateName, String status) { + this.templateName = templateName; + this.status = status; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(templateName); + out.writeString(status); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEMPLATE_NAME_FIELD, templateName); + builder.field(STATUS_FIELD, status); + builder.endObject(); + return builder; + } + + public static MLDeleteContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLDeleteContextManagementTemplateResponse) { + return (MLDeleteContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLDeleteContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLDeleteContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4220dafe25 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class 
MLGetContextManagementTemplateAction extends ActionType { + public static MLGetContextManagementTemplateAction INSTANCE = new MLGetContextManagementTemplateAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/get"; + + private MLGetContextManagementTemplateAction() { + super(NAME, MLGetContextManagementTemplateResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java new file mode 100644 index 0000000000..f8f8061868 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLGetContextManagementTemplateRequest extends ActionRequest { + + String templateName; + + @Builder + public MLGetContextManagementTemplateRequest(String templateName) { + this.templateName = templateName; + } + + public MLGetContextManagementTemplateRequest(StreamInput in) throws IOException { + super(in); + this.templateName = in.readString(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (templateName == null || templateName.trim().isEmpty()) { + exception = addValidationError("Template name cannot be null or empty", exception); + } + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(templateName); + } + + public static MLGetContextManagementTemplateRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLGetContextManagementTemplateRequest) { + return (MLGetContextManagementTemplateRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLGetContextManagementTemplateRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java 
b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java new file mode 100644 index 0000000000..309d4c88af --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLGetContextManagementTemplateResponse.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLGetContextManagementTemplateResponse extends ActionResponse implements ToXContentObject { + + private ContextManagementTemplate template; + + public MLGetContextManagementTemplateResponse(StreamInput in) throws IOException { + super(in); + this.template = new ContextManagementTemplate(in); + } + + public MLGetContextManagementTemplateResponse(ContextManagementTemplate template) { + this.template = template; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + template.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return template.toXContent(builder, params); + } + + public static MLGetContextManagementTemplateResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLGetContextManagementTemplateResponse) { + return (MLGetContextManagementTemplateResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLGetContextManagementTemplateResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLGetContextManagementTemplateResponse", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..2b18f92e20 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import org.opensearch.action.ActionType; + +public class MLListContextManagementTemplatesAction extends ActionType { + public static MLListContextManagementTemplatesAction INSTANCE = new MLListContextManagementTemplatesAction(); + public static final String NAME = "cluster:admin/opensearch/ml/context_management/list"; + + private 
MLListContextManagementTemplatesAction() { + super(NAME, MLListContextManagementTemplatesResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java new file mode 100644 index 0000000000..7f86ad63f6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; + +@Getter +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@ToString +public class MLListContextManagementTemplatesRequest extends ActionRequest { + + int from; + int size; + + @Builder + public MLListContextManagementTemplatesRequest(int from, int size) { + this.from = from; + this.size = size; + } + + public MLListContextManagementTemplatesRequest(StreamInput in) throws IOException { + super(in); + this.from = in.readInt(); + this.size = in.readInt(); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + // No specific validation needed for list request + return exception; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(from); + out.writeInt(size); + } + + public static MLListContextManagementTemplatesRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLListContextManagementTemplatesRequest) { + return (MLListContextManagementTemplatesRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionRequest into MLListContextManagementTemplatesRequest", e); + } + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java new file mode 100644 index 0000000000..bc66395100 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/contextmanagement/MLListContextManagementTemplatesResponse.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.ml.common.transport.contextmanagement; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; + +import lombok.Getter; + +@Getter +public class MLListContextManagementTemplatesResponse extends ActionResponse implements ToXContentObject { + public static final String TEMPLATES_FIELD = "templates"; + public static final String TOTAL_FIELD = "total"; + + private List templates; + private long total; + + public MLListContextManagementTemplatesResponse(StreamInput in) throws IOException { + super(in); + this.templates = in.readList(ContextManagementTemplate::new); + this.total = in.readLong(); + } + + public MLListContextManagementTemplatesResponse(List templates, long total) { + this.templates = templates; + this.total = total; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(templates); + out.writeLong(total); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOTAL_FIELD, total); + builder.startArray(TEMPLATES_FIELD); + for (ContextManagementTemplate template : templates) { + template.toXContent(builder, params); + } + builder.endArray(); + builder.endObject(); + return builder; + } + + public static MLListContextManagementTemplatesResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLListContextManagementTemplatesResponse) { + return (MLListContextManagementTemplatesResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLListContextManagementTemplatesResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("Failed to parse ActionResponse into MLListContextManagementTemplatesResponse", e); + } + } +} diff --git a/common/src/main/resources/index-mappings/ml_agent.json b/common/src/main/resources/index-mappings/ml_agent.json index 9d4deeca51..c530711fb2 100644 --- a/common/src/main/resources/index-mappings/ml_agent.json +++ b/common/src/main/resources/index-mappings/ml_agent.json @@ -43,6 +43,24 @@ "last_updated_time": { "type": "date", "format": "strict_date_time||epoch_millis" + }, + "context_management_name": { + "type": "keyword" + }, + "context_management": { + "type": "object", + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + } + } } } } diff --git a/common/src/main/resources/index-mappings/ml_context_management_templates.json b/common/src/main/resources/index-mappings/ml_context_management_templates.json new file mode 100644 index 0000000000..534be6702d --- /dev/null +++ 
b/common/src/main/resources/index-mappings/ml_context_management_templates.json @@ -0,0 +1,26 @@ +{ + "dynamic": false, + "properties": { + "name": { + "type": "keyword" + }, + "description": { + "type": "text" + }, + "hooks": { + "type": "object", + "enabled": false + }, + "created_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "last_modified": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "created_by": { + "type": "keyword" + } + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index da2f5f5c1e..cf0747603e 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -29,6 +29,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.search.SearchModule; public class MLAgentTest { @@ -65,6 +66,8 @@ public void constructor_NullName() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -86,6 +89,8 @@ public void constructor_NullType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -107,6 +112,8 @@ public void constructor_NullLLMSpec() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -128,6 +135,8 @@ public void constructor_DuplicateTool() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -146,6 +155,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -174,6 +185,8 @@ public void writeTo_NullLLM() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -197,6 +210,8 @@ public void writeTo_NullTools() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -220,6 +235,8 @@ public void writeTo_NullParameters() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -243,6 +260,8 @@ public void writeTo_NullMemory() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -279,6 +298,8 @@ public void toXContent() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); @@ -336,6 +357,8 @@ public void fromStream() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); BytesStreamOutput output = new BytesStreamOutput(); @@ -367,6 +390,8 @@ public void constructor_InvalidAgentType() { Instant.EPOCH, "test", false, + null, + null, null ); } @@ -386,6 +411,8 @@ public void constructor_NonConversationalNoLLM() { Instant.EPOCH, "test", false, + null, + null, null ); assertNotNull(agent); // Ensuring object creation was successful without throwing an exception @@ -396,7 +423,22 @@ public void constructor_NonConversationalNoLLM() { @Test public void writeTo_ReadFrom_HiddenFlag_VersionCompatibility() throws IOException { - MLAgent agent = new MLAgent("test", "FLOW", "test", null, null, null, null, Instant.EPOCH, Instant.EPOCH, "test", true, null); + MLAgent agent = new MLAgent( + "test", 
+ "FLOW", + "test", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test", + true, + null, + null, + null + ); // Serialize and deserialize with an older version BytesStreamOutput output = new BytesStreamOutput(); @@ -460,6 +502,8 @@ public void getTags() { Instant.EPOCH, "test_app", true, + null, + null, null ); @@ -486,6 +530,8 @@ public void getTags_NullValues() { Instant.EPOCH, "test_app", null, + null, + null, null ); @@ -497,4 +543,325 @@ public void getTags_NullValues() { assertFalse(tagsMap.containsKey("memory_type")); assertFalse(tagsMap.containsKey("_llm_interface")); } + + @Test + public void constructor_ConflictingContextManagement() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Cannot specify both context_management_name and context_management"); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + new ContextManagementTemplate(), + null + ); + } + + @Test + public void hasContextManagement_WithTemplateName() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_WithInlineConfig() { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertEquals(template, agent.getInlineContextManagement()); + } + + @Test + public void hasContextManagement_NoContextManagement() { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + null, + null + ); + + assertFalse(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + assertNull(agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertEquals("template_name", deserializedAgent.getContextManagementTemplateName()); + 
assertNull(deserializedAgent.getInlineContextManagement()); + assertTrue(deserializedAgent.hasContextManagement()); + assertTrue(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagementInline() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("test_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_3_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_3_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNotNull(deserializedAgent.getInlineContextManagement()); + assertEquals("test_template", deserializedAgent.getInlineContextManagement().getName()); + assertEquals("test description", deserializedAgent.getInlineContextManagement().getDescription()); + assertTrue(deserializedAgent.hasContextManagement()); + assertFalse(deserializedAgent.hasContextManagementTemplate()); + } + + @Test + public void writeTo_ReadFrom_ContextManagement_VersionCompatibility() throws IOException { + MLAgent agent = new MLAgent( + "test_agent", + MLAgentType.FLOW.name(), + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + // Serialize with older version (before context management support) + BytesStreamOutput output = new BytesStreamOutput(); + output.setVersion(CommonValue.VERSION_3_2_0); + agent.writeTo(output); + + StreamInput streamInput = output.bytes().streamInput(); + streamInput.setVersion(CommonValue.VERSION_3_2_0); + MLAgent deserializedAgent = new MLAgent(streamInput); + + // Context management fields should be null for older versions + assertNull(deserializedAgent.getContextManagementTemplateName()); + assertNull(deserializedAgent.getInlineContextManagement()); + assertFalse(deserializedAgent.hasContextManagement()); + } + + @Test + public void parse_WithContextManagementName() throws IOException { + String jsonStr = "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management_name\":\"template_name\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertEquals("template_name", agent.getContextManagementTemplateName()); + assertNull(agent.getInlineContextManagement()); + assertTrue(agent.hasContextManagement()); + assertTrue(agent.hasContextManagementTemplate()); + } + + @Test + public void parse_WithInlineContextManagement() throws IOException { + String jsonStr = + "{\"name\":\"test\",\"type\":\"FLOW\",\"context_management\":{\"name\":\"inline_template\",\"description\":\"test\",\"hooks\":{\"POST_TOOL\":[]}}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, 
Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); + parser.nextToken(); + MLAgent agent = MLAgent.parseFromUserInput(parser); + + assertEquals("test", agent.getName()); + assertEquals("FLOW", agent.getType()); + assertNull(agent.getContextManagementTemplateName()); + assertNotNull(agent.getInlineContextManagement()); + assertEquals("inline_template", agent.getInlineContextManagement().getName()); + assertEquals("test", agent.getInlineContextManagement().getDescription()); + assertTrue(agent.hasContextManagement()); + assertFalse(agent.hasContextManagementTemplate()); + } + + @Test + public void toXContent_WithContextManagementName() throws IOException { + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + "template_name", + null, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertTrue(content.contains("\"context_management_name\":\"template_name\"")); + assertFalse(content.contains("\"context_management\":")); + } + + @Test + public void toXContent_WithInlineContextManagement() throws IOException { + ContextManagementTemplate template = ContextManagementTemplate + .builder() + .name("inline_template") + .description("test description") + .hooks(Map.of("POST_TOOL", List.of())) + .build(); + + MLAgent agent = new MLAgent( + "test", + "FLOW", + "test description", + null, + null, + null, + null, + Instant.EPOCH, + Instant.EPOCH, + "test_app", + false, + null, + template, + null + ); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + agent.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + assertFalse(content.contains("\"context_management_name\":")); + assertTrue(content.contains("\"context_management\":")); + assertTrue(content.contains("\"inline_template\"")); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java new file mode 100644 index 0000000000..8eb5d7978f --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/contextmanager/CharacterBasedTokenCounterTest.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.contextmanager; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for CharacterBasedTokenCounter. 
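 *
 * (Editorial aside, not part of the patch: the tests below assume the counter approximates
 * one token per four characters, i.e. count(text) == Math.ceil(text.length() / 4.0), so
 * "Hi" counts as 1 token and "This is a test message" (22 characters) as 6; the truncate*
 * methods drop characters from the end, the beginning, or the middle so the result stays
 * within roughly maxTokens * 4 characters.)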
+ */ +public class CharacterBasedTokenCounterTest { + + private CharacterBasedTokenCounter tokenCounter; + + @Before + public void setUp() { + tokenCounter = new CharacterBasedTokenCounter(); + } + + @Test + public void testCountWithNullText() { + Assert.assertEquals(0, tokenCounter.count(null)); + } + + @Test + public void testCountWithEmptyText() { + Assert.assertEquals(0, tokenCounter.count("")); + } + + @Test + public void testCountWithShortText() { + String text = "Hi"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithMediumText() { + String text = "This is a test message"; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testCountWithLongText() { + String text = "This is a very long text that should result in multiple tokens when counted using the character-based approach."; + int expectedTokens = (int) Math.ceil(text.length() / 4.0); + Assert.assertEquals(expectedTokens, tokenCounter.count(text)); + } + + @Test + public void testTruncateFromEndWithNullText() { + Assert.assertNull(tokenCounter.truncateFromEnd(null, 10)); + } + + @Test + public void testTruncateFromEndWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromEnd("", 10)); + } + + @Test + public void testTruncateFromEndWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromEnd(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromEndWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromEnd(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.startsWith(result)); + } + + @Test + public void testTruncateFromBeginningWithNullText() { + Assert.assertNull(tokenCounter.truncateFromBeginning(null, 10)); + } + + @Test + public void testTruncateFromBeginningWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateFromBeginning("", 10)); + } + + @Test + public void testTruncateFromBeginningWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateFromBeginning(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateFromBeginningWithLongText() { + String text = "This is a very long text that needs to be truncated"; + String result = tokenCounter.truncateFromBeginning(text, 5); + + Assert.assertNotNull(result); + Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + Assert.assertTrue(text.endsWith(result)); + } + + @Test + public void testTruncateMiddleWithNullText() { + Assert.assertNull(tokenCounter.truncateMiddle(null, 10)); + } + + @Test + public void testTruncateMiddleWithEmptyText() { + Assert.assertEquals("", tokenCounter.truncateMiddle("", 10)); + } + + @Test + public void testTruncateMiddleWithShortText() { + String text = "Short"; + String result = tokenCounter.truncateMiddle(text, 10); + Assert.assertEquals(text, result); + } + + @Test + public void testTruncateMiddleWithLongText() { + String text = "This is a very long text that needs to be truncated from the middle"; + String result = tokenCounter.truncateMiddle(text, 5); + + Assert.assertNotNull(result); + 
Assert.assertTrue(result.length() < text.length()); + Assert.assertTrue(result.length() <= 5 * 4); // 5 tokens * 4 chars per token + + // Result should contain parts from both beginning and end + int halfChars = (5 * 4) / 2; + String expectedBeginning = text.substring(0, halfChars); + String expectedEnd = text.substring(text.length() - halfChars); + + Assert.assertTrue(result.startsWith(expectedBeginning)); + Assert.assertTrue(result.endsWith(expectedEnd)); + } + + @Test + public void testTruncateConsistency() { + String text = "This is a test text for truncation consistency"; + int maxTokens = 3; + + String fromEnd = tokenCounter.truncateFromEnd(text, maxTokens); + String fromBeginning = tokenCounter.truncateFromBeginning(text, maxTokens); + String fromMiddle = tokenCounter.truncateMiddle(text, maxTokens); + + // All truncated results should have similar token counts + int tokensFromEnd = tokenCounter.count(fromEnd); + int tokensFromBeginning = tokenCounter.count(fromBeginning); + int tokensFromMiddle = tokenCounter.count(fromMiddle); + + Assert.assertTrue(tokensFromEnd <= maxTokens); + Assert.assertTrue(tokensFromBeginning <= maxTokens); + Assert.assertTrue(tokensFromMiddle <= maxTokens); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java index 81a173dfde..58921f71b8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLAgentGetResponseTest.java @@ -96,6 +96,8 @@ public void writeTo() throws IOException { Instant.EPOCH, "test", false, + null, + null, null ); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); @@ -115,7 +117,7 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null); + mlAgent = new MLAgent("mock", MLAgentType.FLOW.name(), "test", null, null, null, null, null, null, "test", false, null, null, null); MLAgentGetResponse mlAgentGetResponse = MLAgentGetResponse.builder().mlAgent(mlAgent).build(); XContentBuilder builder = XContentFactory.jsonBuilder(); ToXContent.Params params = EMPTY_PARAMS; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java index 94cbbeb7dd..da7f3d5623 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java @@ -10,6 +10,9 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.Before; import org.junit.Rule; @@ -21,6 +24,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; public class MLRegisterAgentRequestTest { @@ -111,4 +116,260 @@ public void writeTo(StreamOutput out) throws IOException { }; 
MLRegisterAgentRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_ContextManagementConflict() { + // Create agent with both context management name and inline configuration + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + // This should throw an exception during MLAgent construction + try { + MLAgent agentWithConflict = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Cannot specify both context_management_name and context_management")); + } + } + + @Test + public void validate_ContextManagementTemplateName_Valid() { + MLAgent agentWithTemplateName = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("valid_template_name") + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithTemplateName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Empty() { + // Test empty template name - this should be caught at request validation level + MLAgent agentWithEmptyName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithEmptyName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot be null or empty")); + } + + @Test + public void validate_ContextManagementTemplateName_TooLong() { + // Test template name that's too long + String longName = "a".repeat(257); + MLAgent agentWithLongName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithLongName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void validate_ContextManagementTemplateName_InvalidCharacters() { + // Test template name with invalid characters + MLAgent agentWithInvalidName = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidName); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue( + exception + .toString() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + } + + @Test + public void validate_InlineContextManagement_Valid() { + ContextManagementTemplate validContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(createValidHooks()) + .build(); + + MLAgent agentWithInlineConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(validContextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInlineConfig); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void 
validate_InlineContextManagement_InvalidHookName() { + // Create a context management template with invalid hook name but valid structure + // This should pass MLAgent validation but fail request validation + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(invalidHooks) + .build(); + + MLAgent agentWithInvalidConfig = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(invalidContextManagement) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithInvalidConfig); + ActionRequestValidationException exception = request.validate(); + + assertNotNull(exception); + assertTrue(exception.toString().contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void validate_InlineContextManagement_EmptyHooks() { + ContextManagementTemplate emptyHooksTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(new HashMap<>()) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithEmptyHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(emptyHooksTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_InlineContextManagement_InvalidContextManagerConfig() { + Map> hooksWithInvalidConfig = new HashMap<>(); + hooksWithInvalidConfig + .put( + "POST_TOOL", + Arrays + .asList( + new ContextManagerConfig(null, null, null) // Invalid: null type + ) + ); + + ContextManagementTemplate invalidTemplate = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(hooksWithInvalidConfig) + .build(); + + // This should throw an exception during MLAgent construction due to invalid context management + try { + MLAgent agentWithInvalidConfig = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidTemplate).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid context management configuration")); + } + } + + @Test + public void validate_NoContextManagement_Valid() { + MLAgent agentWithoutContextManagement = MLAgent.builder().name("test_agent").type("flow").build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithoutContextManagement); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_NullValue() { + // Test null template name - this should pass validation since null is acceptable + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + @Test + public void validate_ContextManagementTemplateName_Null() { + // Test null template name validation + MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build(); + + MLRegisterAgentRequest request = new 
MLRegisterAgentRequest(agentWithNullName); + ActionRequestValidationException exception = request.validate(); + + // This should pass since null is handled differently than empty + assertNull(exception); + } + + @Test + public void validate_InlineContextManagement_NullHooks() { + // Test inline context management with null hooks + ContextManagementTemplate contextManagementWithNullHooks = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(null) + .build(); + + MLAgent agentWithNullHooks = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagement(contextManagementWithNullHooks) + .build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullHooks); + ActionRequestValidationException exception = request.validate(); + + // Should pass since null hooks are handled gracefully + assertNull(exception); + } + + @Test + public void validate_HookName_AllValidTypes() { + // Test all valid hook names to improve branch coverage + Map> allValidHooks = new HashMap<>(); + allValidHooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + allValidHooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + allValidHooks.put("PRE_TOOL", Arrays.asList(new ContextManagerConfig("MemoryManager", null, null))); + allValidHooks.put("POST_LLM", Arrays.asList(new ContextManagerConfig("ConversationManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate + .builder() + .name("test_template") + .hooks(allValidHooks) + .build(); + + MLAgent agentWithAllHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithAllHooks); + ActionRequestValidationException exception = request.validate(); + + assertNull(exception); + } + + /** + * Helper method to create valid hooks configuration for testing + */ + private Map> createValidHooks() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + hooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null))); + return hooks; + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java new file mode 100644 index 0000000000..83180d9551 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -0,0 +1,182 @@ +package org.opensearch.ml.engine.agents; + +import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.CHAT_HISTORY; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.SYSTEM_PROMPT_FIELD; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.hooks.EnhancedPostToolEvent; +import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.hooks.PreLLMEvent; +import 
org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.memory.ConversationIndexMemory; + +public class AgentContextUtil { + private static final Logger log = LogManager.getLogger(AgentContextUtil.class); + + public static ContextManagerContext buildContextManagerContextForToolOutput( + String toolOutput, + Map parameters, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + contextParameters.put("_current_tool_output", toolOutput); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object extractProcessedToolOutput(ContextManagerContext context) { + if (context.getParameters() != null) { + return context.getParameters().get("_current_tool_output"); + } + return null; + } + + public static Object extractFromContext(ContextManagerContext context, String key) { + if (context.getParameters() != null) { + return context.getParameters().get(key); + } + return null; + } + + public static ContextManagerContext buildContextManagerContext( + Map parameters, + List interactions, + List toolSpecs, + Memory memory + ) { + ContextManagerContext.ContextManagerContextBuilder builder = ContextManagerContext.builder(); + + String systemPrompt = parameters.get(SYSTEM_PROMPT_FIELD); + if (systemPrompt != null) { + builder.systemPrompt(systemPrompt); + } + + String userPrompt = parameters.get(QUESTION); + if (userPrompt != null) { + builder.userPrompt(userPrompt); + } + + if (memory instanceof ConversationIndexMemory) { + String chatHistory = parameters.get(CHAT_HISTORY); + // TODO to add chatHistory into context, currently there is no context manager working on chat_history + } + + if (toolSpecs != null) { + builder.toolConfigs(toolSpecs); + } + + builder.toolInteractions(interactions != null ? interactions : new ArrayList<>()); + + Map contextParameters = new HashMap<>(); + contextParameters.putAll(parameters); + builder.parameters(contextParameters); + + return builder.build(); + } + + public static Object emitPostToolHook( + Object toolOutput, + Map parameters, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + if (hookRegistry != null) { + try { + if (toolOutput == null) { + log.warn("Tool output is null, skipping POST_TOOL hook"); + return null; + } + ContextManagerContext context = buildContextManagerContextForToolOutput( + StringUtils.toJson(toolOutput), + parameters, + toolSpecs, + memory + ); + EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); + hookRegistry.emit(event); + + Object processedOutput = extractProcessedToolOutput(context); + return processedOutput != null ? 
processedOutput : toolOutput; + } catch (Exception e) { + log.error("Failed to emit POST_TOOL hook event", e); + return toolOutput; + } + } + return toolOutput; + } + + public static ContextManagerContext emitPreLLMHook( + Map parameters, + List interactions, + List toolSpecs, + Memory memory, + HookRegistry hookRegistry + ) { + ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory); + + try { + PreLLMEvent event = new PreLLMEvent(context, new HashMap<>()); + hookRegistry.emit(event); + return context; + + } catch (Exception e) { + log.error("Failed to emit PRE_LLM hook event", e); + return context; + } + } + + public static void updateParametersFromContext(Map parameters, ContextManagerContext context) { + if (context.getSystemPrompt() != null) { + parameters.put(SYSTEM_PROMPT_FIELD, context.getSystemPrompt()); + } + + if (context.getUserPrompt() != null) { + parameters.put(QUESTION, context.getUserPrompt()); + } + + if (context.getChatHistory() != null && !context.getChatHistory().isEmpty()) { + } + + if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) { + parameters.put(INTERACTIONS, ", " + String.join(", ", context.getToolInteractions())); + } + + if (context.getParameters() != null) { + for (Map.Entry entry : context.getParameters().entrySet()) { + parameters.put(entry.getKey(), entry.getValue()); + + } + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 87a266a490..bb37abe940 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -58,7 +58,9 @@ import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.ModelProvider; import org.opensearch.ml.common.agent.ModelProviderFactory; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.memory.Memory; @@ -71,6 +73,9 @@ import org.opensearch.ml.common.settings.SettingsChangeListener; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.engine.Executable; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; @@ -209,6 +214,12 @@ public void execute(Input input, ActionListener listener, TransportChann ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLAgent mlAgent = MLAgent.parse(parser); + // Use existing HookRegistry from AgentMLInput if available (set by MLExecuteTaskRunner for template + // references) + // Otherwise create a fresh HookRegistry for agent execution + final HookRegistry hookRegistry = agentMLInput.getHookRegistry() != null + ? 
agentMLInput.getHookRegistry() + : new HookRegistry(); if (isMultiTenancyEnabled && !Objects.equals(tenantId, mlAgent.getTenantId())) { listener .onFailure( @@ -287,7 +298,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); }, e -> { log.error("Failed to get existing interaction for regeneration", e); @@ -304,7 +316,8 @@ public void execute(Input input, ActionListener listener, TransportChann outputs, modelTensors, mlAgent, - channel + channel, + hookRegistry ); } }, ex -> { @@ -319,8 +332,9 @@ public void execute(Input input, ActionListener listener, TransportChann ConversationIndexMemory.Factory factory = (ConversationIndexMemory.Factory) memoryFactoryMap .get(memorySpec.getType()); - if (memoryFactory != null) { - // memoryId exists, so create returns an object with existing memory, therefore name can + if (factory != null) { + // memoryId exists, so create returns an object with existing + // memory, therefore name can // be null Map memoryParams = createMemoryParams( question, @@ -330,7 +344,7 @@ public void execute(Input input, ActionListener listener, TransportChann inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) ); - memoryFactory + factory .create( memoryParams, ActionListener @@ -345,7 +359,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, createdMemory, - channel + channel, + hookRegistry ), ex -> { log.error("Failed to find memory with memory_id: {}", memoryId, ex); @@ -366,7 +381,8 @@ public void execute(Input input, ActionListener listener, TransportChann modelTensors, listener, null, - channel + channel, + hookRegistry ); } } catch (Exception e) { @@ -396,10 +412,11 @@ public void execute(Input input, ActionListener listener, TransportChann /** * save root interaction and start execute the agent - * @param listener callback listener - * @param memory memory instance + * + * @param listener callback listener + * @param memory memory instance * @param inputDataSet input - * @param mlAgent agent to run + * @param mlAgent agent to run */ private void saveRootInteractionAndExecute( ActionListener listener, @@ -410,7 +427,8 @@ private void saveRootInteractionAndExecute( List outputs, List modelTensors, MLAgent mlAgent, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); @@ -444,7 +462,8 @@ private void saveRootInteractionAndExecute( modelTensors, listener, memory, - channel + channel, + hookRegistry ), e -> { log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e); @@ -453,7 +472,19 @@ private void saveRootInteractionAndExecute( ) ); } else { - executeAgent(inputDataSet, mlTask, isAsync, memory.getId(), mlAgent, outputs, modelTensors, listener, memory, channel); + executeAgent( + inputDataSet, + mlTask, + isAsync, + memory.getId(), + mlAgent, + outputs, + modelTensors, + listener, + memory, + channel, + hookRegistry + ); } }, ex -> { log.error("Failed to create parent interaction", ex); @@ -461,6 +492,210 @@ private void saveRootInteractionAndExecute( })); } + /** + * Process context management configuration and register context managers in + * hook registry + * + * @param mlAgent the ML agent with context management configuration + * @param hookRegistry the hook registry to register context managers with + * @param inputDataSet the input dataset to update with 
context management info + */ + private void processContextManagement(MLAgent mlAgent, HookRegistry hookRegistry, RemoteInferenceInputDataSet inputDataSet) { + try { + // Check if context_management is already specified in runtime parameters + String runtimeContextManagement = inputDataSet.getParameters().get("context_management"); + if (runtimeContextManagement != null && !runtimeContextManagement.trim().isEmpty()) { + log.info("Using runtime context management parameter: {}", runtimeContextManagement); + return; // Runtime parameter takes precedence, let MLExecuteTaskRunner handle it + } + + ContextManagementTemplate template = null; + String templateName = null; + + if (mlAgent.hasContextManagementTemplate()) { + // Template reference - would need to be resolved from template service + templateName = mlAgent.getContextManagementTemplateName(); + log.info("Agent '{}' has context management template reference: {}", mlAgent.getName(), templateName); + // For now, we'll pass the template name to parameters for MLExecuteTaskRunner + // to handle + inputDataSet.getParameters().put("context_management", templateName); + return; // Let MLExecuteTaskRunner handle template resolution + } else if (mlAgent.getInlineContextManagement() != null) { + // Inline template - process directly + template = mlAgent.getInlineContextManagement(); + templateName = template.getName(); + log.info("Agent '{}' has inline context management configuration: {}", mlAgent.getName(), templateName); + } + + if (template != null) { + // Process inline context management template + processInlineContextManagement(template, hookRegistry); + // Mark as processed to prevent MLExecuteTaskRunner from processing it again + inputDataSet.getParameters().put("context_management_processed", "true"); + inputDataSet.getParameters().put("context_management", templateName); + } + } catch (Exception e) { + log.error("Failed to process context management for agent '{}': {}", mlAgent.getName(), e.getMessage(), e); + // Don't fail the entire execution, just log the error + } + } + + /** + * Process inline context management template and register context managers + * + * @param template the context management template + * @param hookRegistry the hook registry to register with + */ + private void processInlineContextManagement(ContextManagementTemplate template, HookRegistry hookRegistry) { + try { + log.debug("Processing inline context management template: {}", template.getName()); + + // Fresh HookRegistry ensures no duplicate registrations + + // Create context managers from template configuration + List contextManagers = createContextManagers(template); + + if (!contextManagers.isEmpty()) { + // Create hook provider and register with hook registry + org.opensearch.ml.common.contextmanager.ContextManagerHookProvider hookProvider = + new org.opensearch.ml.common.contextmanager.ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks with the registry + hookProvider.registerHooks(hookRegistry); + + log.info("Successfully registered {} context managers from template '{}'", contextManagers.size(), template.getName()); + } else { + log.warn("No context managers created from template '{}'", template.getName()); + } + } catch (Exception e) { + log.error("Failed to process inline context management template '{}': {}", template.getName(), e.getMessage(), e); + } + } + + /** + * Create context managers from template configuration + 
* + * @param template the context management template + * @return list of created context managers + */ + private List createContextManagers(ContextManagementTemplate template) { + List managers = new ArrayList<>(); + + try { + // Iterate through all hooks and their configurations + for (Map.Entry> entry : template + .getHooks() + .entrySet()) { + String hookName = entry.getKey(); + List configs = entry.getValue(); + + log.debug("Processing hook '{}' with {} configurations", hookName, configs.size()); + + for (org.opensearch.ml.common.contextmanager.ContextManagerConfig config : configs) { + try { + org.opensearch.ml.common.contextmanager.ContextManager manager = createContextManager(config); + if (manager != null) { + managers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } + } catch (Exception e) { + log + .error( + "Failed to create context manager of type '{}' for hook '{}': {}", + config.getType(), + hookName, + e.getMessage(), + e + ); + } + } + } + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", e.getMessage(), e); + } + + return managers; + } + + /** + * Create a single context manager from configuration + * + * @param config the context manager configuration + * @return the created context manager or null if creation failed + */ + private org.opensearch.ml.common.contextmanager.ContextManager createContextManager( + org.opensearch.ml.common.contextmanager.ContextManagerConfig config + ) { + try { + String type = config.getType(); + Map managerConfig = config.getConfig(); + + log.debug("Creating context manager of type: {}", type); + + // Create context manager based on type + switch (type) { + case "ToolsOutputTruncateManager": + return createToolsOutputTruncateManager(managerConfig); + case "SummarizationManager": + case "SummarizingManager": + return createSummarizationManager(managerConfig); + case "MemoryManager": + return createMemoryManager(managerConfig); + case "ConversationManager": + return createConversationManager(managerConfig); + default: + log.warn("Unknown context manager type: {}", type); + return null; + } + } catch (Exception e) { + log.error("Failed to create context manager: {}", e.getMessage(), e); + return null; + } + } + + /** + * Create ToolsOutputTruncateManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createToolsOutputTruncateManager(Map config) { + log.debug("Creating ToolsOutputTruncateManager with config: {}", config); + ToolsOutputTruncateManager manager = new ToolsOutputTruncateManager(); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SummarizationManager + */ + private org.opensearch.ml.common.contextmanager.ContextManager createSummarizationManager(Map config) { + log.debug("Creating SummarizationManager with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + + /** + * Create SlidingWindowManager (used for MemoryManager type) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createMemoryManager(Map config) { + log.debug("Creating SlidingWindowManager (MemoryManager) with config: {}", config); + SlidingWindowManager manager = new SlidingWindowManager(); + manager.initialize(config != null ? 
config : new HashMap<>()); + return manager; + } + + /** + * Create ConversationManager (placeholder - using SummarizationManager for now) + */ + private org.opensearch.ml.common.contextmanager.ContextManager createConversationManager(Map config) { + log.debug("Creating ConversationManager (using SummarizationManager as placeholder) with config: {}", config); + SummarizationManager manager = new SummarizationManager(client); + manager.initialize(config != null ? config : new HashMap<>()); + return manager; + } + private void executeAgent( RemoteInferenceInputDataSet inputDataSet, MLTask mlTask, @@ -471,7 +706,8 @@ private void executeAgent( List modelTensors, ActionListener listener, Memory memory, - TransportChannel channel + TransportChannel channel, + HookRegistry hookRegistry ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { @@ -480,10 +716,17 @@ private void executeAgent( return; } - MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent); + // Check for agent-level context management configuration (following connector + // pattern) + if (mlAgent.hasContextManagement()) { + processContextManagement(mlAgent, hookRegistry, inputDataSet); + } + + MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent, hookRegistry); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); - // If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists + // If async is true, index ML task and return the taskID. Also add memoryID to + // the task if it exists if (isAsync) { Map agentResponse = new HashMap<>(); if (memoryId != null && !memoryId.isEmpty()) { @@ -620,7 +863,7 @@ private ActionListener createAsyncTaskUpdater( } @VisibleForTesting - protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { + protected MLAgentRunner getAgentRunner(MLAgent mlAgent, HookRegistry hookRegistry) { final MLAgentType agentType = MLAgentType.from(mlAgent.getType().toUpperCase(Locale.ROOT)); switch (agentType) { case FLOW: @@ -654,7 +897,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); case PLAN_EXECUTE_AND_REFLECT: return new MLPlanExecuteAndReflectAgentRunner( @@ -665,7 +909,8 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { toolFactories, memoryFactoryMap, sdkClient, - encryptor + encryptor, + hookRegistry ); default: throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 21b1bc5762..f371e5244c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -63,7 +63,9 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.memory.Message; import 
org.opensearch.ml.common.output.model.ModelTensor; @@ -72,14 +74,13 @@ import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; import org.opensearch.ml.engine.function_calling.FunctionCallingFactory; import org.opensearch.ml.engine.function_calling.LLMMessage; import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.engine.tools.MLModelTool; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; -import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -138,6 +139,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private SdkClient sdkClient; private Encryptor encryptor; private StreamingWrapper streamingWrapper; + private static HookRegistry hookRegistry; public MLChatAgentRunner( Client client, @@ -148,6 +150,20 @@ public MLChatAgentRunner( Map memoryFactoryMap, SdkClient sdkClient, Encryptor encryptor + ) { + this(client, settings, clusterService, xContentRegistry, toolFactories, memoryFactoryMap, sdkClient, encryptor, null); + } + + public MLChatAgentRunner( + Client client, + Settings settings, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + Map toolFactories, + Map memoryFactoryMap, + SdkClient sdkClient, + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -157,6 +173,7 @@ public MLChatAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; } @Override @@ -199,7 +216,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener for (Interaction next : r) { String question = next.getInput(); String response = next.getResponse(); - // As we store the conversation with empty response first and then update when have final answer, + // As we store the conversation with empty response first and then update when + // have final answer, // filter out those in-flight requests when run in parallel if (Strings.isNullOrEmpty(response)) { continue; @@ -223,7 +241,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener } params.put(CHAT_HISTORY, chatHistoryBuilder.toString()); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, chatHistoryBuilder.toString()); } else { List chatHistory = new ArrayList<>(); @@ -244,7 +263,8 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener params.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); params.put(NEW_CHAT_HISTORY, String.join(", ", chatHistory) + ", "); - // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added to input params to validate + // required for MLChatAgentRunnerTest.java, it requires chatHistory to be added + // to input params to validate inputParams.put(CHAT_HISTORY, String.join(", ", chatHistory) + ", "); } } @@ -323,7 +343,7 @@ private void runReAct( StepListener lastStepListener = firstListener; 
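As an illustrative aside on the HookRegistry these runner constructors now accept: it is typically populated from an agent's context management configuration. Below is a minimal sketch of assembling such a configuration, reusing only builder calls visible elsewhere in this patch (ContextManagementTemplate, ContextManagerConfig, MLAgent.builder().contextManagement(...)); the example class name and the generic types on the hooks map are assumed, not taken from the patch.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.contextmanager.ContextManagementTemplate;
import org.opensearch.ml.common.contextmanager.ContextManagerConfig;

public class InlineContextManagementExample {
    public static MLAgent buildAgentWithInlineContextManagement() {
        // Hook name -> context manager configs; generic types assumed (the patch elides them)
        Map<String, List<ContextManagerConfig>> hooks = new HashMap<>();
        hooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null)));
        hooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null)));

        ContextManagementTemplate template = ContextManagementTemplate
            .builder()
            .name("example_template")
            .hooks(hooks)
            .build();

        // MLAgentExecutor.processContextManagement later resolves this inline template into concrete
        // context managers and registers them on a fresh HookRegistry for the execution
        return MLAgent.builder().name("example_agent").type("flow").contextManagement(template).build();
    }
}
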
StringBuilder scratchpadBuilder = new StringBuilder(); - List interactions = new CopyOnWriteArrayList<>(); + final List interactions = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -339,6 +359,7 @@ private void runReAct( if (finalI % 2 == 0) { MLTaskResponse llmResponse = (MLTaskResponse) output; ModelTensorOutput tmpModelTensorOutput = (ModelTensorOutput) llmResponse.getOutput(); + List llmResponsePatterns = gson.fromJson(tmpParameters.get("llm_response_pattern"), List.class); Map modelOutput = parseLLMOutput( parameters, @@ -457,6 +478,7 @@ private void runReAct( ((ActionListener) nextStepListener).onResponse(res); } } else { + // filteredOutput is the POST Tool output Object filteredOutput = filterToolOutput(lastToolParams, output); addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); @@ -485,11 +507,13 @@ private void runReAct( newPrompt.set(substitutor.replace(finalPrompt)); tmpParameters.put(PROMPT, newPrompt.get()); if (!interactions.isEmpty()) { - tmpParameters.put(INTERACTIONS, ", " + String.join(", ", interactions)); + String interactionsStr = String.join(", ", interactions); + // Set the interactions parameter - this will be processed by context management + tmpParameters.put(INTERACTIONS, ", " + interactionsStr); } sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); - streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId); + streamingWrapper.sendToolResponse(outputToOutputString(filteredOutput), sessionId, parentInteractionId); traceTensors .add( ModelTensors @@ -521,6 +545,26 @@ private void runReAct( ); return; } + // Emit PRE_LLM hook event + if (hookRegistry != null && !interactions.isEmpty()) { + List currentToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); } @@ -533,8 +577,29 @@ private void runReAct( } } + // Emit PRE_LLM hook event for initial LLM call + List initialToolSpecs = new ArrayList<>(toolSpecMap.values()); + tmpParameters.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !interactions.isEmpty()) { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedInteractions = contextAfterEvent.getToolInteractions(); + if (updatedInteractions != null && !updatedInteractions.equals(interactions)) { + 
interactions.clear(); + interactions.addAll(updatedInteractions); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + tmpParameters.put(INTERACTIONS, contextInteractions); + } + } + } ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); streamingWrapper.executeRequest(request, firstListener); + } private static List createFinalAnswerTensors(List sessionId, List lastThought) { @@ -584,7 +649,9 @@ private static void addToolOutputToAddtionalInfo( List list = (List) additionalInfo.get(toolOutputKey); list.add(outputString); } else { - additionalInfo.put(toolOutputKey, Lists.newArrayList(outputString)); + List newList = new ArrayList<>(); + newList.add(outputString); + additionalInfo.put(toolOutputKey, newList); } } } @@ -607,17 +674,29 @@ private static void runTool( ActionListener toolListener = ActionListener.wrap(r -> { if (functionCalling != null) { String outputResponse = parseResponse(filterToolOutput(toolParams, r)); + + // Emit POST_TOOL hook event after tool execution and process current tool + // output + List postToolSpecs = new ArrayList<>(toolSpecMap.values()); + String outputResponseAfterHook = AgentContextUtil + .emitPostToolHook(outputResponse, tmpParameters, postToolSpecs, null, hookRegistry) + .toString(); + List> toolResults = List - .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse))); + .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponseAfterHook))); List llmMessages = functionCalling.supply(toolResults); - // TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here + // TODO: support multiple tool calls at the same time so that multiple + // LLMMessages can be generated here interactions.add(llmMessages.getFirst().getResponse()); } else { + // Emit POST_TOOL hook event for non-function calling path + List postToolSpecs = new ArrayList<>(toolSpecMap.values()); + Object processedOutput = AgentContextUtil.emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); interactions .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(r))), + Map.of(TOOL_CALL_ID, toolCallId, "tool_response", processTextDoc(StringUtils.toJson(processedOutput))), INTERACTIONS_PREFIX ) ); @@ -669,9 +748,13 @@ private static void runTool( } /** - * In each tool runs, it copies agent parameters, which is tmpParameters into a new set of parameter llmToolTmpParameters, - * after the tool runs, normally llmToolTmpParameters will be discarded, but for some special parameters like SCRATCHPAD_NOTES_KEY, - * some new llmToolTmpParameters produced by the tool run can opt to be copied back to tmpParameters to share across tools in the same interaction + * In each tool runs, it copies agent parameters, which is tmpParameters into a + * new set of parameter llmToolTmpParameters, + * after the tool runs, normally llmToolTmpParameters will be discarded, but for + * some special parameters like SCRATCHPAD_NOTES_KEY, + * some new llmToolTmpParameters produced by the tool run can opt to be copied + * back to tmpParameters to share across tools in the same interaction + * * @param tmpParameters * @param llmToolTmpParameters */ @@ -868,7 +951,7 @@ public static void returnFinalResponse( ModelTensor 
.builder() .name("response") - .dataAsMap(ImmutableMap.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) + .dataAsMap(Map.of("response", finalAnswer2, ADDITIONAL_INFO_FIELD, additionalInfo)) .build() ) ); @@ -938,4 +1021,5 @@ private void saveMessage( memory.save(msgTemp, parentInteractionId, traceNumber.addAndGet(1), "LLM", listener); } } + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 8d8a854217..9ad4bed833 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -22,6 +22,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMlToolSpecs; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.MAX_ITERATION; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.saveTraceData; @@ -55,9 +56,11 @@ import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.memory.Memory; @@ -71,6 +74,7 @@ import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; @@ -94,6 +98,7 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { private final Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private HookRegistry hookRegistry; // flag to track if task has been updated with executor memory ids or not private boolean taskUpdated = false; private final Map taskUpdates = new HashMap<>(); @@ -165,7 +170,8 @@ public MLPlanExecuteAndReflectAgentRunner( Map toolFactories, Map memoryFactoryMap, SdkClient sdkClient, - Encryptor encryptor + Encryptor encryptor, + HookRegistry hookRegistry ) { this.client = client; this.settings = settings; @@ -175,6 +181,7 @@ public MLPlanExecuteAndReflectAgentRunner( this.memoryFactoryMap = memoryFactoryMap; this.sdkClient = sdkClient; this.encryptor = encryptor; + this.hookRegistry = hookRegistry; this.plannerPrompt = DEFAULT_PLANNER_PROMPT; this.plannerPromptTemplate = DEFAULT_PLANNER_PROMPT_TEMPLATE; 
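Both the chat runner above and this planner runner now funnel their prompt state through a PRE_LLM event before each model call. A minimal sketch of that call path, mirroring the runner code in this patch; the generic types (Map<String, String> parameters, List<String> interactions, List<MLToolSpec> tool specs) and the example class name are assumptions of this sketch.

import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS;

import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.contextmanager.ContextManagerContext;
import org.opensearch.ml.common.hooks.HookRegistry;
import org.opensearch.ml.engine.agents.AgentContextUtil;

public class PreLlmHookExample {
    public static void applyPreLlmHook(
        Map<String, String> parameters,
        List<String> interactions,
        List<MLToolSpec> toolSpecs,
        HookRegistry hookRegistry
    ) {
        // Registered PRE_LLM context managers (e.g. SummarizationManager, SlidingWindowManager)
        // may rewrite the interactions carried by the returned context
        ContextManagerContext context = AgentContextUtil.emitPreLLMHook(parameters, interactions, toolSpecs, null, hookRegistry);

        List<String> updated = context.getToolInteractions();
        if (updated != null && !updated.equals(interactions)) {
            interactions.clear();
            interactions.addAll(updated);
            String contextInteractions = context.getParameters().get(INTERACTIONS);
            if (contextInteractions != null && !contextInteractions.isEmpty()) {
                parameters.put(INTERACTIONS, contextInteractions);
            }
        }
    }
}
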
this.reflectPrompt = DEFAULT_REFLECT_PROMPT; @@ -290,9 +297,6 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); Map memoryParams = createMemoryParams( apiParams.get(USER_PROMPT_FIELD), @@ -303,7 +307,7 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListener { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { - List completedSteps = new ArrayList<>(); + final List completedSteps = new ArrayList<>(); for (Interaction interaction : interactions) { String question = interaction.getInput(); String response = interaction.getResponse(); @@ -397,8 +401,41 @@ private void executePlanningLoop( ); return; } + MLPredictionTaskRequest request; + // Planner agent doesn't use INTERACTIONS for now, reusing the INTERACTIONS to pass over + // completedSteps to context management. + // TODO should refactor the completed steps as message array format, similar to chat agent. + + allParams.put("_llm_model_id", llm.getModelId()); + if (hookRegistry != null && !completedSteps.isEmpty()) { + + Map requestParams = new HashMap<>(allParams); + requestParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps)); + try { + ContextManagerContext contextAfterEvent = AgentContextUtil + .emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry); + + // Check if context managers actually modified the interactions + List updatedSteps = contextAfterEvent.getToolInteractions(); + if (updatedSteps != null && !updatedSteps.equals(completedSteps)) { + completedSteps.clear(); + completedSteps.addAll(updatedSteps); + + // Update parameters if context manager set INTERACTIONS + String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS); + if (contextInteractions != null && !contextInteractions.isEmpty()) { + allParams.put(COMPLETED_STEPS_FIELD, contextInteractions); + // TODO should I always clear interactions after update the completed steps? 
+ allParams.put(INTERACTIONS, ""); + } + } + } catch (Exception e) { + log.error("Failed to emit pre-LLM hook", e); + } - MLPredictionTaskRequest request = new MLPredictionTaskRequest( + } + + request = new MLPredictionTaskRequest( llm.getModelId(), RemoteInferenceMLInput .builder() @@ -454,6 +491,9 @@ private void executePlanningLoop( .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build()) .build(); + // Pass hookRegistry to internal agent execution + agentInput.setHookRegistry(hookRegistry); + MLExecuteTaskRequest executeRequest = new MLExecuteTaskRequest(FunctionName.AGENT, agentInput); client.execute(MLExecuteTaskAction.INSTANCE, executeRequest, ActionListener.wrap(executeResponse -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java new file mode 100644 index 0000000000..c541045aaf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements a sliding window approach for tool interactions. + * Keeps only the most recent N interactions to prevent context window overflow. + * This manager ensures proper handling of different message types while tool execution flow. 
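A minimal usage sketch of this manager in isolation, assuming raw Map<String, Object> config values, string-encoded interactions, and an example class name of our own:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.contextmanager.ContextManagerContext;
import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager;

public class SlidingWindowManagerExample {
    public static List<String> keepRecentInteractions(List<String> interactions) {
        SlidingWindowManager manager = new SlidingWindowManager();

        Map<String, Object> config = new HashMap<>();
        config.put("max_messages", 4); // keep at most the 4 most recent interactions
        manager.initialize(config);

        ContextManagerContext context = ContextManagerContext.builder().toolInteractions(new ArrayList<>(interactions)).build();

        if (manager.shouldActivate(context)) {
            // Trims older entries at a safe tool-pair boundary and rewrites the _interactions parameter
            manager.execute(context);
        }
        return context.getToolInteractions();
    }
}
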
+ */ +@Log4j2 +public class SlidingWindowManager implements ContextManager { + + public static final String TYPE = "SlidingWindowManager"; + + // Configuration keys + private static final String MAX_MESSAGES_KEY = "max_messages"; + + // Default values + private static final int DEFAULT_MAX_MESSAGES = 20; + + private int maxMessages; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxMessages = parseIntegerConfig(config, MAX_MESSAGES_KEY, DEFAULT_MAX_MESSAGES); + + if (this.maxMessages <= 0) { + log.warn("Invalid max_messages value: {}, using default {}", this.maxMessages, DEFAULT_MAX_MESSAGES); + this.maxMessages = DEFAULT_MAX_MESSAGES; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SlidingWindowManager: maxMessages={}", maxMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + log.debug("No tool interactions to process"); + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int originalSize = interactions.size(); + + if (originalSize <= maxMessages) { + log.debug("Interactions size ({}) is within limit ({}), no truncation needed", originalSize, maxMessages); + return; + } + + // Find safe start point to avoid breaking tool pairs + int startIndex = findSafeStartPoint(interactions, originalSize - maxMessages); + + // Keep the most recent interactions from safe start point + List updatedInteractions = new ArrayList<>(interactions.subList(startIndex, originalSize)); + + // Update toolInteractions in context to keep only the most recent ones + context.setToolInteractions(updatedInteractions); + + // Update the _interactions parameter with smaller size of updated interactions + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + context.setParameters(parameters); + } + parameters.put("_interactions", ", " + String.join(", ", updatedInteractions)); + + int removedMessages = originalSize - updatedInteractions.size(); + log + .info( + "Applied sliding window: kept {} most recent interactions, removed {} older interactions", + updatedInteractions.size(), + removedMessages + ); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, 
value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe start point that doesn't break assistant-tool message pairs + * Same logic as SummarizationManager but for finding start point + */ + private int findSafeStartPoint(List interactions, int targetStartPoint) { + if (targetStartPoint <= 0) { + return 0; + } + if (targetStartPoint >= interactions.size()) { + return interactions.size(); + } + + int startPoint = targetStartPoint; + + while (startPoint < interactions.size()) { + try { + String messageAtStart = interactions.get(startPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = messageAtStart.contains("toolResult"); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtStart.contains("toolUse"); + boolean nextHasToolResult = false; + if (startPoint + 1 < interactions.size()) { + nextHasToolResult = interactions.get(startPoint + 1).contains("toolResult"); + } + + if (hasToolResult || (hasToolUse && startPoint + 1 < interactions.size() && !nextHasToolResult)) { + startPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", startPoint, e.getMessage()); + startPoint++; + } + } + + return startPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java new file mode 100644 index 0000000000..75128e266a --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java @@ -0,0 +1,435 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static java.lang.Math.min; +import static org.opensearch.ml.common.FunctionName.REMOTE; +import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.INTERACTIONS; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import 
org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.transport.client.Client; + +import com.jayway.jsonpath.JsonPath; +import com.jayway.jsonpath.PathNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that implements summarization approach for tool interactions. + * Summarizes older interactions while preserving recent ones to manage context + * window. + */ +@Log4j2 +public class SummarizationManager implements ContextManager { + + public static final String TYPE = "SummarizationManager"; + + // Configuration keys + private static final String SUMMARY_RATIO_KEY = "summary_ratio"; + private static final String PRESERVE_RECENT_MESSAGES_KEY = "preserve_recent_messages"; + private static final String SUMMARIZATION_MODEL_ID_KEY = "summarization_model_id"; + private static final String SUMMARIZATION_SYSTEM_PROMPT_KEY = "summarization_system_prompt"; + + // Default values + private static final double DEFAULT_SUMMARY_RATIO = 0.3; + private static final int DEFAULT_PRESERVE_RECENT_MESSAGES = 10; + private static final String DEFAULT_SUMMARIZATION_PROMPT = + "You are a interactions summarization agent. Summarize the provided interactions concisely while preserving key information and context."; + + protected double summaryRatio; + protected int preserveRecentMessages; + protected String summarizationModelId; + protected String summarizationSystemPrompt; + protected List activationRules; + private Client client; + + public SummarizationManager(Client client) { + this.client = client; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + this.summaryRatio = parseDoubleConfig(config, SUMMARY_RATIO_KEY, DEFAULT_SUMMARY_RATIO); + this.preserveRecentMessages = parseIntegerConfig(config, PRESERVE_RECENT_MESSAGES_KEY, DEFAULT_PRESERVE_RECENT_MESSAGES); + this.summarizationModelId = (String) config.get(SUMMARIZATION_MODEL_ID_KEY); + this.summarizationSystemPrompt = (String) config.getOrDefault(SUMMARIZATION_SYSTEM_PROMPT_KEY, DEFAULT_SUMMARIZATION_PROMPT); + + // Validate summary ratio + if (summaryRatio < 0.1 || summaryRatio > 0.8) { + log.warn("Invalid summary_ratio value: {}, using default {}", summaryRatio, DEFAULT_SUMMARY_RATIO); + this.summaryRatio = DEFAULT_SUMMARY_RATIO; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized SummarizationManager: summaryRatio={}, preserveRecentMessages={}", summaryRatio, preserveRecentMessages); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + List interactions = context.getToolInteractions(); + + if (interactions == null || interactions.isEmpty()) { + return; + } + + if (interactions.isEmpty()) { + log.debug("No string interactions found in tool interactions"); + return; + } + + int totalMessages = interactions.size(); + + // Calculate how 
many messages to summarize + int messagesToSummarizeCount = Math.max(1, (int) (totalMessages * summaryRatio)); + + // Ensure we don't summarize recent messages + messagesToSummarizeCount = min(messagesToSummarizeCount, totalMessages - preserveRecentMessages); + + if (messagesToSummarizeCount <= 0) { + return; + } + + // Find a safe cut point that doesn't break assistant-tool pairs + int safeCutPoint = findSafeCutPoint(interactions, messagesToSummarizeCount); + + if (safeCutPoint <= 0) { + return; + } + + // Extract messages to summarize and remaining messages + List messagesToSummarize = new ArrayList<>(interactions.subList(0, safeCutPoint)); + List remainingMessages = new ArrayList<>(interactions.subList(safeCutPoint, totalMessages)); + + // Get model ID + String modelId = summarizationModelId; + if (modelId == null) { + Map parameters = context.getParameters(); + if (parameters != null) { + modelId = (String) parameters.get("_llm_model_id"); + } + } + + if (modelId == null) { + log.error("No model ID available for summarization"); + return; + } + + // Prepare summarization parameters + Map summarizationParameters = new HashMap<>(); + summarizationParameters.put("prompt", "Help summarize the following" + StringUtils.toJson(String.join(",", messagesToSummarize))); + summarizationParameters.put("system_prompt", summarizationSystemPrompt); + + executeSummarization(context, modelId, summarizationParameters, safeCutPoint, remainingMessages, interactions); + } + + protected void executeSummarization( + ContextManagerContext context, + String modelId, + Map summarizationParameters, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + CountDownLatch latch = new CountDownLatch(1); + + try { + // Create ML input dataset for remote inference + MLInputDataset inputDataset = RemoteInferenceInputDataSet.builder().parameters(summarizationParameters).build(); + + // Create ML input + MLInput mlInput = MLInput.builder().algorithm(REMOTE).inputDataset(inputDataset).build(); + + // Create prediction request + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId(modelId).mlInput(mlInput).build(); + + // Execute prediction + ActionListener listener = ActionListener.wrap(response -> { + try { + String summary = extractSummaryFromResponse(response, context); + processSummarizationResult(context, summary, messagesToSummarizeCount, remainingMessages, originalInteractions); + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }, e -> { + try { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous tool interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } finally { + latch.countDown(); + } + }); + + client.execute(MLPredictionTaskAction.INSTANCE, request, listener); + + // Wait for summarization to complete (30 second timeout) + latch.await(30, TimeUnit.SECONDS); + + } catch (Exception e) { + // Fallback to default behavior + processSummarizationResult( + context, + "Summarized " + messagesToSummarizeCount + " previous interactions", + messagesToSummarizeCount, + remainingMessages, + originalInteractions + ); + } + } + + protected void processSummarizationResult( + ContextManagerContext context, + 
String summary, + int messagesToSummarizeCount, + List remainingMessages, + List originalInteractions + ) { + try { + // Create summarized interaction + String summarizedInteraction = "{\"role\":\"assistant\",\"content\":\"Summarized previous interactions: " + + processTextDoc(summary) + + "\"}"; + + // Update interactions: summary + remaining messages + List updatedInteractions = new ArrayList<>(); + updatedInteractions.add(summarizedInteraction); + updatedInteractions.addAll(remainingMessages); + + // Update toolInteractions in context + context.setToolInteractions(updatedInteractions); + + // Update parameters + Map parameters = context.getParameters(); + if (parameters == null) { + parameters = new HashMap<>(); + } + parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions)); + context.setParameters(parameters); + + log + .info( + "Summarization completed: {} messages summarized, {} messages preserved", + messagesToSummarizeCount, + remainingMessages.size() + ); + + } catch (Exception e) { + log.error("Failed to process summarization result", e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response, ContextManagerContext context) { + try { + MLOutput output = response.getOutput(); + if (output instanceof ModelTensorOutput) { + ModelTensorOutput tensorOutput = (ModelTensorOutput) output; + List mlModelOutputs = tensorOutput.getMlModelOutputs(); + + if (mlModelOutputs != null && !mlModelOutputs.isEmpty()) { + List tensors = mlModelOutputs.get(0).getMlModelTensors(); + if (tensors != null && !tensors.isEmpty()) { + Map dataAsMap = tensors.get(0).getDataAsMap(); + + // Use LLM_RESPONSE_FILTER from agent configuration if available + Map parameters = context.getParameters(); + if (parameters != null + && parameters.containsKey(LLM_RESPONSE_FILTER) + && !parameters.get(LLM_RESPONSE_FILTER).isEmpty()) { + try { + String responseFilter = parameters.get(LLM_RESPONSE_FILTER); + Object filteredResponse = JsonPath.read(dataAsMap, responseFilter); + if (filteredResponse instanceof String) { + String result = ((String) filteredResponse).trim(); + return result; + } else { + String result = StringUtils.toJson(filteredResponse); + return result; + } + } catch (PathNotFoundException e) { + // Fall back to default parsing + } catch (Exception e) { + // Fall back to default parsing + } + } + + // Fallback to default parsing if no filter or filter fails + if (dataAsMap.size() == 1 && dataAsMap.containsKey("response")) { + Object responseObj = dataAsMap.get("response"); + if (responseObj instanceof String) { + return ((String) responseObj).trim(); + } + } + + // Last resort: return JSON representation + return StringUtils.toJson(dataAsMap); + } + } + } + } catch (Exception e) { + log.error("Failed to extract summary from response", e); + } + + return "Summary generation failed"; + } + + private double parseDoubleConfig(Map config, String key, double defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Double) { + return (Double) value; + } else if (value instanceof Number) { + return ((Number) value).doubleValue(); + } else if (value instanceof String) { + return Double.parseDouble((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid double value for config key '{}': {}, using default {}", key, value, 
defaultValue); + return defaultValue; + } + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } + + /** + * Find a safe cut point that doesn't break assistant-tool message pairs + * Exact same logic as Strands agent + */ + private int findSafeCutPoint(List interactions, int targetCutPoint) { + if (targetCutPoint >= interactions.size()) { + return targetCutPoint; + } + // // the current agent logic is when odd number it's tool called result and even number is tool input, should always summarize for + // pairs, so the targetCutPoint needs to be even + // if (targetCutPoint%2==0){ + // return targetCutPoint; + // } else { + // return min(targetCutPoint+1,interactions.size()); + // } + int splitPoint = targetCutPoint; + + while (splitPoint < interactions.size()) { + try { + String messageAtSplit = interactions.get(splitPoint); + + // Oldest message cannot be a toolResult because it needs a toolUse preceding it + boolean hasToolResult = (messageAtSplit.contains("toolResult") || messageAtSplit.contains("tool_call_id")); + + // Oldest message can be a toolUse only if a toolResult immediately follows it + boolean hasToolUse = messageAtSplit.contains("toolUse"); + boolean nextHasToolResult = false; + // TODO we need better way to handle the tool result based on the llm interfaces. + if (splitPoint + 1 < interactions.size()) { + nextHasToolResult = (interactions.get(splitPoint + 1).contains("toolResult") + || messageAtSplit.contains("tool_call_id")); + } + + if (hasToolResult || (hasToolUse && splitPoint + 1 < interactions.size() && !nextHasToolResult)) { + splitPoint++; + } else { + break; + } + + } catch (Exception e) { + log.warn("Error checking message at index {}: {}", splitPoint, e.getMessage()); + splitPoint++; + } + } + + return splitPoint; + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java new file mode 100644 index 0000000000..4fa97c156d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/ToolsOutputTruncateManager.java @@ -0,0 +1,134 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.List; +import java.util.Map; + +import org.opensearch.ml.common.contextmanager.ActivationRule; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +import lombok.extern.log4j.Log4j2; + +/** + * Context manager that truncates tool output to prevent context window overflow. 
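A short sketch of how the POST_TOOL path exercises this manager, feeding the raw tool output through the _current_tool_output parameter; the generic types and the example class name are assumed rather than taken from the patch:

import java.util.HashMap;
import java.util.Map;

import org.opensearch.ml.common.contextmanager.ContextManagerContext;
import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager;

public class ToolsOutputTruncateExample {
    public static String truncateToolOutput(String rawToolOutput) {
        ToolsOutputTruncateManager manager = new ToolsOutputTruncateManager();

        Map<String, Object> config = new HashMap<>();
        config.put("max_output_length", 1000); // cap tool output at 1,000 characters for this example
        manager.initialize(config);

        Map<String, String> parameters = new HashMap<>();
        parameters.put("_current_tool_output", rawToolOutput);
        ContextManagerContext context = ContextManagerContext.builder().parameters(parameters).build();

        if (manager.shouldActivate(context)) {
            // Over-length output is cut and suffixed with a truncation note; short output is left as-is
            manager.execute(context);
        }
        return context.getParameters().get("_current_tool_output");
    }
}
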
+ * This manager processes the current tool output and applies length limits. + */ +@Log4j2 +public class ToolsOutputTruncateManager implements ContextManager { + + public static final String TYPE = "ToolsOutputTruncateManager"; + + // Configuration keys + private static final String MAX_OUTPUT_LENGTH_KEY = "max_output_length"; + + // Default values + private static final int DEFAULT_MAX_OUTPUT_LENGTH = 40000; + + private int maxOutputLength; + private List activationRules; + + @Override + public String getType() { + return TYPE; + } + + @Override + public void initialize(Map config) { + // Initialize configuration with defaults + this.maxOutputLength = parseIntegerConfig(config, MAX_OUTPUT_LENGTH_KEY, DEFAULT_MAX_OUTPUT_LENGTH); + + if (this.maxOutputLength <= 0) { + log.warn("Invalid max_output_length value: {}, using default {}", this.maxOutputLength, DEFAULT_MAX_OUTPUT_LENGTH); + this.maxOutputLength = DEFAULT_MAX_OUTPUT_LENGTH; + } + + // Initialize activation rules from config + @SuppressWarnings("unchecked") + Map activationConfig = (Map) config.get("activation"); + this.activationRules = ActivationRuleFactory.createRules(activationConfig); + + log.info("Initialized ToolsOutputTruncateManager: maxOutputLength={}", maxOutputLength); + } + + @Override + public boolean shouldActivate(ContextManagerContext context) { + if (activationRules == null || activationRules.isEmpty()) { + return true; + } + + for (ActivationRule rule : activationRules) { + if (!rule.evaluate(context)) { + log.debug("Activation rule not satisfied: {}", rule.getDescription()); + return false; + } + } + + log.debug("All activation rules satisfied, manager will execute"); + return true; + } + + @Override + public void execute(ContextManagerContext context) { + // Process current tool output from parameters + Map parameters = context.getParameters(); + if (parameters == null) { + log.debug("No parameters available for tool output truncation"); + return; + } + + Object currentToolOutput = parameters.get("_current_tool_output"); + if (currentToolOutput == null) { + log.debug("No current tool output to process"); + return; + } + + String outputString = currentToolOutput.toString(); + int originalLength = outputString.length(); + + if (originalLength <= maxOutputLength) { + log.debug("Tool output length ({}) is within limit ({}), no truncation needed", originalLength, maxOutputLength); + return; + } + + // Truncate the output + String truncatedOutput = outputString.substring(0, maxOutputLength); + + // Add truncation indicator + truncatedOutput += "... 
[Output truncated - original length: " + originalLength + " characters]"; + + // Update the current tool output in parameters + parameters.put("_current_tool_output", truncatedOutput); + + int truncatedLength = truncatedOutput.length(); + log.info("Tool output truncated: original length {} -> truncated length {}", originalLength, truncatedLength); + } + + private int parseIntegerConfig(Map config, String key, int defaultValue) { + Object value = config.get(key); + if (value == null) { + return defaultValue; + } + + try { + if (value instanceof Integer) { + return (Integer) value; + } else if (value instanceof Number) { + return ((Number) value).intValue(); + } else if (value instanceof String) { + return Integer.parseInt((String) value); + } else { + log.warn("Invalid type for config key '{}': {}, using default {}", key, value.getClass().getSimpleName(), defaultValue); + return defaultValue; + } + } catch (NumberFormatException e) { + log.warn("Invalid integer value for config key '{}': {}, using default {}", key, value, defaultValue); + return defaultValue; + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index 344e28e487..2e8c612c43 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -5,1617 +5,300 @@ package org.opensearch.ml.engine.algorithms.agent; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.when; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; -import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; -import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE; -import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.MEMORY_ID; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.QUESTION; -import static org.opensearch.ml.engine.algorithms.agent.MLAgentExecutor.REGENERATE_INTERACTION_ID; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; -import java.io.IOException; -import java.net.InetAddress; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; -import javax.naming.Context; - -import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Captor; import org.mockito.Mock; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchException; -import org.opensearch.ResourceNotFoundException; -import org.opensearch.Version; -import org.opensearch.action.get.GetRequest; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexResponse; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.cluster.node.DiscoveryNode; import 
org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; -import org.opensearch.ml.common.agent.MLMemorySpec; -import org.opensearch.ml.common.agent.MLToolSpec; -import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.memory.Memory; -import org.opensearch.ml.common.output.MLTaskOutput; import org.opensearch.ml.common.output.Output; -import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.spi.tools.Tool; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; -import org.opensearch.ml.engine.memory.MLMemoryManager; -import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; -import org.opensearch.ml.memory.action.conversation.GetInteractionAction; -import org.opensearch.ml.memory.action.conversation.GetInteractionResponse; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.remote.metadata.client.SdkClient; -import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; -import com.google.gson.Gson; - -import software.amazon.awssdk.utils.ImmutableMap; - +@SuppressWarnings({ "rawtypes" }) public class MLAgentExecutorTest { @Mock private Client client; - SdkClient sdkClient; - private Settings settings; - @Mock - private ClusterService clusterService; + @Mock - private ClusterState clusterState; + private SdkClient sdkClient; + @Mock - private Metadata metadata; + private ClusterService clusterService; + @Mock private NamedXContentRegistry xContentRegistry; + @Mock - private Map toolFactories; - @Mock - private Map memoryMap; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock - private IndexResponse indexResponse; + private Encryptor encryptor; + @Mock private ThreadPool threadPool; + private ThreadContext threadContext; - @Mock - private Context context; - @Mock - private ConversationIndexMemory.Factory mockMemoryFactory; - @Mock - private ActionListener agentActionListener; - @Mock - private MLAgentRunner mlAgentRunner; @Mock - private ConversationIndexMemory memory; - @Mock - private MLMemoryManager memoryManager; - private 
MLAgentExecutor mlAgentExecutor; + private ThreadContext.StoredContext storedContext; @Mock - private MLFeatureEnabledSetting mlFeatureEnabledSetting; - - @Captor - private ArgumentCaptor objectCaptor; + private TransportChannel channel; - @Captor - private ArgumentCaptor exceptionCaptor; + @Mock + private ActionListener listener; - private DiscoveryNode localNode = new DiscoveryNode( - "mockClusterManagerNodeId", - "mockClusterManagerNodeId", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + @Mock + private GetResponse getResponse; - MLAgent mlAgent; + private MLAgentExecutor mlAgentExecutor; + private Map toolFactories; + private Map memoryFactoryMap; + private Settings settings; @Before - @SuppressWarnings("unchecked") public void setup() { MockitoAnnotations.openMocks(this); + settings = Settings.builder().build(); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); + toolFactories = new HashMap<>(); + memoryFactoryMap = new HashMap<>(); threadContext = new ThreadContext(settings); - memoryMap = ImmutableMap.of("memoryType", mockMemoryFactory); - Mockito.doAnswer(invocation -> { - MLMemorySpec mlMemorySpec = MLMemorySpec.builder().type("memoryType").build(); - MLAgent mlAgent = MLAgent.builder().name("agent").memory(mlMemorySpec).type("flow").build(); - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(true); - Mockito.when(getResponse.getSourceAsBytesRef()).thenReturn(BytesReference.bytes(content)); - listener.onResponse(getResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.when(clusterService.state()).thenReturn(clusterState); - Mockito.when(clusterState.metadata()).thenReturn(metadata); - when(clusterService.localNode()).thenReturn(localNode); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(true); - Mockito.when(memory.getMemoryManager()).thenReturn(memoryManager); + when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(this.clusterService.getSettings()).thenReturn(settings); - when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(ML_COMMONS_MCP_CONNECTOR_ENABLED))); - - // Mock MLFeatureEnabledSetting when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(false); - when(mlFeatureEnabledSetting.isMcpConnectorEnabled()).thenReturn(true); - - settings = Settings.builder().build(); - mlAgentExecutor = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - mlFeatureEnabledSetting, - null - ) - ); - - } - - @Test - public void test_NoAgentIndex() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Mockito.when(metadata.hasIndex(Mockito.anyString())).thenReturn(false); - - 
mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Agent index not found"); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NullInput_ThrowsException() { - mlAgentExecutor.execute(null, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonAgentInput_ThrowsException() { - Input input = new Input() { - @Override - public FunctionName getFunctionName() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return null; - } - }; - mlAgentExecutor.execute(input, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputData_ThrowsException() { - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, null); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test(expected = IllegalArgumentException.class) - public void test_NonInputParas_ThrowsException() { - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(null).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, inputDataSet); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - } - - @Test - public void test_HappyCase_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensor_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", 
"test_value")).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - List response = Arrays.asList(modelTensor1, modelTensor2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(response, output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfModelTensors_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List response = Arrays.asList(modelTensors1, modelTensors2); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = 
(ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_AgentRunnerReturnsListOfString_ReturnsResult() throws IOException { - List response = Arrays.asList("response1", "response2"); - Mockito.doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(response); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Gson gson = new Gson(); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(gson.toJson(response), output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsString_ReturnsResult() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("response"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals("response", 
output.getMlModelOutputs().get(0).getMlModelTensors().get(0).getResult()); - } - - @Test - public void test_AgentRunnerReturnsModelTensorOutput_ReturnsResult() throws IOException { - ModelTensor modelTensor1 = ModelTensor.builder().name("response1").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors1 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor1)).build(); - ModelTensor modelTensor2 = ModelTensor.builder().name("response2").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ModelTensors modelTensors2 = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor2)).build(); - List modelTensorsList = Arrays.asList(modelTensors1, modelTensors2); - ModelTensorOutput modelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(modelTensorsList).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensorOutput); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(2, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(Arrays.asList(modelTensor1, modelTensor2), output.getMlModelOutputs().get(0).getMlModelTensors()); - } - - @Test - public void test_CreateConversation_ReturnsResult() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - 
ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - @Test - public void test_Regenerate_Validation() throws IOException { - Map params = new HashMap<>(); - params.put(REGENERATE_INTERACTION_ID, "foo"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert.assertEquals(exception.getMessage(), "A memory ID must be provided to regenerate."); - } - - @Test - public void test_Regenerate_GetOriginalInteraction() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = 
Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(Boolean.TRUE); - return null; - }).when(memoryManager).deleteInteractionAndTrace(Mockito.anyString(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - GetInteractionResponse interactionResponse = Mockito.mock(GetInteractionResponse.class); - Interaction mockInteraction = Mockito.mock(Interaction.class); - Mockito.when(mockInteraction.getInput()).thenReturn("regenerate question"); - Mockito.when(interactionResponse.getInteraction()).thenReturn(mockInteraction); - listener.onResponse(interactionResponse); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - String interactionId = "bar-interaction"; - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, interactionId); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertEquals(params.get(QUESTION), "regenerate question"); - // original interaction got deleted - Mockito.verify(memoryManager, times(1)).deleteInteractionAndTrace(Mockito.eq(interactionId), Mockito.any()); - } - - @Test - public void test_Regenerate_OriginalInteraction_NotExist() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - // Extract the ActionListener argument from the method invocation - ActionListener listener = invocation.getArgument(1); - // Trigger the onResponse method of the ActionListener with the mock response - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - CreateInteractionResponse interaction = Mockito.mock(CreateInteractionResponse.class); - Mockito.when(interaction.getId()).thenReturn("interaction_id"); - Mockito.doAnswer(invocation -> { - ActionListener 
responseActionListener = invocation.getArgument(4); - responseActionListener.onResponse(interaction); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new ResourceNotFoundException("Interaction bar-interaction not found")); - return null; - }).when(client).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "foo-memory"); - params.put(REGENERATE_INTERACTION_ID, "bar-interaction"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(client, times(1)).execute(Mockito.eq(GetInteractionAction.INSTANCE), Mockito.any(), Mockito.any()); - Assert.assertNull(params.get(QUESTION)); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof ResourceNotFoundException); - Assert.assertEquals(exception.getMessage(), "Interaction bar-interaction not found"); - } - - @Test - public void test_CreateFlowAgent() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("flow").build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLFlowAgentRunner); - } - - @Test - public void test_CreateChatAgent() { - LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); - MLAgent mlAgent = MLAgent.builder().name("test_agent").type(MLAgentType.CONVERSATIONAL.name()).llm(llmSpec).build(); - MLAgentRunner mlAgentRunner = mlAgentExecutor.getAgentRunner(mlAgent); - Assert.assertTrue(mlAgentRunner instanceof MLChatAgentRunner); - } - - @Test(expected = IllegalArgumentException.class) - public void test_InvalidAgent_ThrowsException() { - MLAgent mlAgent = MLAgent.builder().name("test_agent").type("illegal").build(); - mlAgentExecutor.getAgentRunner(mlAgent); - } - - @Test - public void test_GetModel_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException()); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_GetModelDoesNotExist_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - GetResponse getResponse = Mockito.mock(GetResponse.class); - Mockito.when(getResponse.isExists()).thenReturn(false); - listener.onResponse(getResponse); - return null; - 
}).when(client).get(Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateConversationFailure_ThrowsException() { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(new RuntimeException()); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_CreateInteractionFailure_ThrowsException() { - ConversationIndexMemory memory = Mockito.mock(ConversationIndexMemory.class); - Mockito.doAnswer(invocation -> { - ActionListener responseActionListener = invocation.getArgument(4); - responseActionListener.onFailure(new RuntimeException()); - return null; - }).when(memory).save(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - Mockito.when(memory.getConversationId()).thenReturn("conversation_id"); - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - Map params = new HashMap<>(); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AgentRunnerFailure_ReturnsResult() { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException()); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(getAgentMLInput(), agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_AsyncMode_ReturnsTaskId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = 
invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task_id", 1, 0, 2, true); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput result = (MLTaskOutput) objectCaptor.getValue(); - - Assert.assertEquals("task_id", result.getTaskId()); - Assert.assertEquals("RUNNING", result.getStatus()); - } - - @Test - public void test_AsyncMode_IndexTask_failure() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").result("test").build(); - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onFailure(new Exception("Index Not Found")); - return null; - }).when(client).index(any(), any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertNotNull(exceptionCaptor.getValue()); - } - - @Test - public void test_mcp_connector_requires_mcp_connector_enabled() throws IOException { - // Create an MLAgent with MCP connectors in parameters - Map parameters = new HashMap<>(); - parameters.put(MCP_CONNECTORS_FIELD, "[{\"connector_id\": \"test-connector\"}]"); - - MLAgent mlAgentWithMcpConnectors = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - Collections.emptyList(), - parameters, - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null + // Mock ClusterService for the agent index check + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Simulate index not found + + mlAgentExecutor = new MLAgentExecutor( + client, + sdkClient, + settings, + clusterService, + xContentRegistry, + toolFactories, + 
memoryFactoryMap, + mlFeatureEnabledSetting, + encryptor ); - - // Create GetResponse with the MLAgent that has MCP connectors - XContentBuilder content = mlAgentWithMcpConnectors.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with MCP connector disabled - MLFeatureEnabledSetting disabledMcpSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(disabledMcpSetting.isMultiTenancyEnabled()).thenReturn(false); - when(disabledMcpSetting.isMcpConnectorEnabled()).thenReturn(false); - - MLAgentExecutor mlAgentExecutorWithDisabledMcp = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - disabledMcpSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithDisabledMcp).getAgentRunner(Mockito.any()); - - // Execute the agent - mlAgentExecutorWithDisabledMcp.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution fails with the correct error message - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof OpenSearchException); - Assert.assertEquals(exception.getMessage(), ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE); } @Test - public void test_query_planning_agentic_search_enabled() throws IOException { - // Create an MLAgent with QueryPlanningTool - MLAgent mlAgentWithQueryPlanning = new MLAgent( - "test", - MLAgentType.FLOW.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "QueryPlanningTool", - "QueryPlanningTool", - "QueryPlanningTool", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - false, - null - ); - - // Create GetResponse with the MLAgent that has QueryPlanningTool - XContentBuilder content = mlAgentWithQueryPlanning.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "test-agent-id", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse agentGetResponse = new GetResponse(getResult); - - // Create a new MLAgentExecutor with agentic search enabled - MLFeatureEnabledSetting enabledSearchSetting = Mockito.mock(MLFeatureEnabledSetting.class); - when(enabledSearchSetting.isMultiTenancyEnabled()).thenReturn(false); - when(enabledSearchSetting.isMcpConnectorEnabled()).thenReturn(true); - - MLAgentExecutor 
mlAgentExecutorWithEnabledSearch = Mockito - .spy( - new MLAgentExecutor( - client, - sdkClient, - settings, - clusterService, - xContentRegistry, - toolFactories, - memoryMap, - enabledSearchSetting, - null - ) - ); - - // Mock the agent get response - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - // Mock the agent runner - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutorWithEnabledSearch).getAgentRunner(Mockito.any()); - - // Mock successful execution - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - // Execute the agent - mlAgentExecutorWithEnabledSearch.execute(getAgentMLInput(), agentActionListener); - - // Verify that the execution succeeds - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); - } - - private AgentMLInput getAgentMLInput() { - Map params = new HashMap<>(); - params.put(MLAgentExecutor.MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - return new AgentMLInput("test", null, FunctionName.AGENT, dataset); - } - - public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenantId) throws IOException { - - mlAgent = new MLAgent( - "test", - MLAgentType.CONVERSATIONAL.name(), - "test", - new LLMSpec("test_model", Map.of("test_key", "test_value")), - List - .of( - new MLToolSpec( - "memoryType", - "test", - "test", - Collections.emptyMap(), - Collections.emptyMap(), - false, - Collections.emptyMap(), - null, - null - ) - ), - Map.of("test", "test"), - new MLMemorySpec("memoryType", "123", 0), - Instant.EPOCH, - Instant.EPOCH, - "test", - isHidden, - tenantId - ); - - XContentBuilder content = mlAgent.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", agentId, 111l, 111l, 111l, true, bytesReference, null, null); - return new GetResponse(getResult); + public void testConstructor() { + assertNotNull(mlAgentExecutor); + assertEquals(client, mlAgentExecutor.getClient()); + assertEquals(settings, mlAgentExecutor.getSettings()); + assertEquals(clusterService, mlAgentExecutor.getClusterService()); + assertEquals(xContentRegistry, mlAgentExecutor.getXContentRegistry()); + assertEquals(toolFactories, mlAgentExecutor.getToolFactories()); + assertEquals(memoryFactoryMap, mlAgentExecutor.getMemoryFactoryMap()); 
+ assertEquals(mlFeatureEnabledSetting, mlAgentExecutor.getMlFeatureEnabledSetting()); + assertEquals(encryptor, mlAgentExecutor.getEncryptor()); } @Test - public void test_BothParentAndRegenerateInteractionId_ThrowsException() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Map params = new HashMap<>(); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent-123"); - params.put(REGENERATE_INTERACTION_ID, "regenerate-456"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testOnMultiTenancyEnabledChanged() { + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + assertTrue(mlAgentExecutor.getIsMultiTenancyEnabled()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertTrue(exception instanceof IllegalArgumentException); - Assert - .assertEquals( - exception.getMessage(), - "Provide either `parent_interaction_id` to update an existing interaction, or `regenerate_interaction_id` to create a new one." - ); + mlAgentExecutor.onMultiTenancyEnabledChanged(false); + assertFalse(mlAgentExecutor.getIsMultiTenancyEnabled()); } @Test - public void test_ExistingConversation_WithMemoryAndParentInteractionId() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testExecuteWithWrongInputType() { + // Test with non-AgentMLInput - create a mock Input that's not AgentMLInput + Input wrongInput = mock(Input.class); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - // Verify memory factory was called with null question and existing memory_id - Mockito.verify(mockMemoryFactory).create(Mockito.eq(null), 
Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); + try { + mlAgentExecutor.execute(wrongInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("wrong input", exception.getMessage()); + } } @Test - public void test_AgentFailure_UpdatesInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Agent execution failed"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory for existing conversation - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "test-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); + public void testExecuteWithNullInputDataSet() { + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, null); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - // Verify failure was propagated to listener - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Assert.assertEquals(testException, exceptionCaptor.getValue()); - - // Verify interaction was updated with failure message - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertTrue(updateContent.get("response").toString().contains("Agent execution failed")); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExistingConversation_MemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - // Mock memory factory failure for existing conversation - RuntimeException memoryException = new RuntimeException("Memory creation failed"); - Mockito.doAnswer(invocation -> { - 
ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("existing-memory"), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "existing-memory"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "existing-parent"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - mlAgentExecutor.execute(agentMLInput, agentActionListener); + public void testExecuteWithNullParameters() { + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Agent input data can not be empty.", exception.getMessage()); + } } @Test - public void test_ExecuteAgent_SyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(false); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - ModelTensorOutput output = (ModelTensorOutput) objectCaptor.getValue(); - Assert.assertEquals(1, output.getMlModelOutputs().size()); - Assert.assertEquals(1, output.getMlModelOutputs().get(0).getMlModelTensors().size()); - Assert.assertEquals(modelTensor, output.getMlModelOutputs().get(0).getMlModelTensors().get(0)); + public void testExecuteWithMultiTenancyEnabledButNoTenantId() { + when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true); + mlAgentExecutor.onMultiTenancyEnabledChanged(true); + + Map parameters = 
Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder() + .parameters(parameters) + .build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); + + try { + mlAgentExecutor.execute(agentInput, listener, channel); + fail("Expected OpenSearchStatusException"); + } catch (OpenSearchStatusException exception) { + assertEquals("You don't have permission to access this resource", exception.getMessage()); + assertEquals(RestStatus.FORBIDDEN, exception.status()); + } } @Test - public void test_ExecuteAgent_AsyncMode() throws IOException { - ModelTensor modelTensor = ModelTensor.builder().name("response").dataAsMap(ImmutableMap.of("test_key", "test_value")).build(); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(modelTensor); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + public void testExecuteWithAgentIndexNotFound() { + Map parameters = Collections.singletonMap("question", "test question"); + RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + AgentMLInput agentInput = new AgentMLInput("test-agent", null, FunctionName.AGENT, dataset); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + // Since we can't mock static methods easily, we'll test a different scenario + // This test would need the actual MLIndicesHandler behavior + mlAgentExecutor.execute(agentInput, listener, channel); - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parentInteractionId"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - agentMLInput.setIsAsync(true); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); - Assert.assertEquals("RUNNING", output.getStatus()); + // Verify that the listener was called (the actual behavior will depend on the implementation) + verify(listener, timeout(5000).atLeastOnce()).onFailure(any()); } @Test - public void test_UpdateInteractionWithFailure() throws IOException { - RuntimeException testException = new RuntimeException("Test failure message"); 
- Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(testException); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - ArgumentCaptor> updateCaptor = ArgumentCaptor.forClass(Map.class); - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), updateCaptor.capture(), Mockito.any()); - Map updateContent = updateCaptor.getValue(); - Assert.assertEquals("Agent execution failed: Test failure message", updateContent.get("response")); + public void testGetAgentRunnerWithFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLFlowAgentRunner); } @Test - public void test_ConversationMemoryCreationFailure() throws IOException { - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", true, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - RuntimeException memoryException = new RuntimeException("Failed to read conversation memory"); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onFailure(memoryException); - return null; - }).when(mockMemoryFactory).create(Mockito.eq("test question"), Mockito.eq(null), Mockito.any(), Mockito.any()); - - Map params = new HashMap<>(); - params.put(QUESTION, "test question"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(agentActionListener).onFailure(exceptionCaptor.capture()); - Exception exception = exceptionCaptor.getValue(); - Assert.assertEquals(memoryException, exception); + public void testGetAgentRunnerWithConversationalFlowAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL_FLOW.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLConversationalFlowAgentRunner); } @Test - public void test_AsyncExecution_NullOutput() throws IOException { - 
Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(null); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithConversationalAgent() { + MLAgent agent = createTestAgent(MLAgentType.CONVERSATIONAL.name()); + MLAgentRunner runner = mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLChatAgentRunner); } @Test - public void test_AsyncExecution_Failure() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent execution failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertNotNull(output.getTaskId()); + public void testGetAgentRunnerWithPlanExecuteAndReflectAgent() { + MLAgent agent = createTestAgent(MLAgentType.PLAN_EXECUTE_AND_REFLECT.name()); + MLAgentRunner runner = 
mlAgentExecutor.getAgentRunner(agent, null); + assertNotNull(runner); + assertTrue(runner instanceof MLPlanExecuteAndReflectAgentRunner); } @Test - public void test_UpdateInteractionFailure_LogLines() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); - } - - @Test - public void test_UpdateInteractionFailure_ErrorCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Test failure")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Update failed")); - return null; - }).when(memoryManager).updateInteraction(Mockito.any(), Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + public void testGetAgentRunnerWithUnsupportedAgentType() { + // Create a mock MLAgent instead of using the constructor that validates + MLAgent agent = mock(MLAgent.class); + when(agent.getType()).thenReturn("UNSUPPORTED_TYPE"); - Map params = new HashMap<>(); - params.put(MEMORY_ID, "memoryId"); - params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "test-parent-id"); - RemoteInferenceInputDataSet dataset = 
RemoteInferenceInputDataSet.builder().parameters(params).build(); - AgentMLInput agentMLInput = new AgentMLInput("test", null, FunctionName.AGENT, dataset); - - mlAgentExecutor.execute(agentMLInput, agentActionListener); - - Mockito.verify(memoryManager).updateInteraction(Mockito.eq("test-parent-id"), Mockito.any(), Mockito.any()); + try { + mlAgentExecutor.getAgentRunner(agent, null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException exception) { + assertEquals("Wrong Agent type", exception.getMessage()); + } } @Test - public void test_AsyncTaskUpdate_SuccessCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse("success"); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); + public void testProcessOutputWithModelTensorOutput() throws Exception { + ModelTensorOutput output = mock(ModelTensorOutput.class); + when(output.getMlModelOutputs()).thenReturn(Collections.emptyList()); - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + List modelTensors = new java.util.ArrayList<>(); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + mlAgentExecutor.processOutput(output, modelTensors); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + verify(output).getMlModelOutputs(); } @Test - public void test_AsyncTaskUpdate_FailureCallback() throws IOException { - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Agent failed")); - return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), 
Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); + public void testProcessOutputWithString() throws Exception { + String output = "test response"; + List modelTensors = new java.util.ArrayList<>(); - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); + mlAgentExecutor.processOutput(output, modelTensors); - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); + assertEquals(1, modelTensors.size()); + assertEquals("response", modelTensors.get(0).getName()); + assertEquals("test response", modelTensors.get(0).getResult()); } - @Test - public void test_AgentRunnerException() throws IOException { - // Reset mocks to ensure clean state - Mockito.reset(mlAgentRunner); - - RuntimeException testException = new RuntimeException("Agent runner threw exception"); - Mockito.doThrow(testException).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); - - GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(agentGetResponse); - return null; - }).when(client).get(Mockito.any(GetRequest.class), Mockito.any(ActionListener.class)); - - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(memory); - return null; - }).when(mockMemoryFactory).create(Mockito.eq(null), Mockito.eq("memoryId"), Mockito.any(), Mockito.any()); - - indexResponse = new IndexResponse(new ShardId(ML_TASK_INDEX, "_na_", 0), "task-123", 1, 0, 2, true); - Mockito.doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(indexResponse); - return null; - }).when(mlAgentExecutor).indexMLTask(Mockito.any(), Mockito.any()); - - Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); - - AgentMLInput input = getAgentMLInput(); - input.setIsAsync(true); - mlAgentExecutor.execute(input, agentActionListener); - - Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); - MLTaskOutput output = (MLTaskOutput) objectCaptor.getValue(); - Assert.assertEquals("task-123", output.getTaskId()); + private MLAgent createTestAgent(String type) { + return MLAgent + .builder() + .name("test-agent") + .type(type) + .description("Test agent") + .llm(LLMSpec.builder().modelId("test-model").parameters(Collections.emptyMap()).build()) + .tools(Collections.emptyList()) + .parameters(Collections.emptyMap()) + .memory(null) + .createdTime(Instant.now()) + .lastUpdateTime(Instant.now()) + .appType("test-app") + .build(); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 115cc8bf03..1855c6fbd0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -710,7 +710,7 @@ public void testToolParameters() { // Verify the size of parameters passed in the tool run method. 
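+        // Note: the expected parameter count asserted below rises by one in this change,
+        // consistent with the runner now passing an additional memory/context-related
+        // parameter to tools (hedged observation; the exact new key is defined elsewhere in this patch).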
ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(15, ((Map) argumentCaptor.getValue()).size()); + assertEquals(16, ((Map) argumentCaptor.getValue()).size()); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); @@ -738,7 +738,7 @@ public void testToolUseOriginalInput() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); assertEquals("raw input", ((Map) argumentCaptor.getValue()).get("input")); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -804,7 +804,7 @@ public void testToolConfig() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be "config_value". assertEquals("config_value", ((Map) argumentCaptor.getValue()).get("input")); @@ -834,7 +834,7 @@ public void testToolConfigWithInputPlaceholder() { // Verify the size of parameters passed in the tool run method. ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); verify(firstTool).run((Map) argumentCaptor.capture(), any()); - assertEquals(16, ((Map) argumentCaptor.getValue()).size()); + assertEquals(17, ((Map) argumentCaptor.getValue()).size()); // The value of input should be replaced with the value associated with the key "key2" of the first tool. 
assertEquals("value2", ((Map) argumentCaptor.getValue()).get("input")); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index a8be11c5f1..985661a9c9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -135,6 +135,7 @@ public void setup() { // memory mlMemorySpec = new MLMemorySpec(ConversationIndexMemory.TYPE, "uuid", 10); + when(memoryMap.get(ConversationIndexMemory.TYPE)).thenReturn(memoryFactory); when(memoryMap.get(anyString())).thenReturn(memoryFactory); when(conversationIndexMemory.getConversationId()).thenReturn("test_memory_id"); when(conversationIndexMemory.getMemoryManager()).thenReturn(mlMemoryManager); @@ -170,7 +171,8 @@ public void setup() { toolFactories, memoryMap, sdkClient, - encryptor + encryptor, + null ); // Setup tools diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java new file mode 100644 index 0000000000..692c0ebf7c --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManagerTest.java @@ -0,0 +1,234 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; + +/** + * Unit tests for SlidingWindowManager. 
+ */ +public class SlidingWindowManagerTest { + + private SlidingWindowManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + manager = new SlidingWindowManager(); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SlidingWindowManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should initialize with default values without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + + manager.initialize(config); + + // Should initialize without throwing exceptions + Assert.assertNotNull(manager); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should always activate when no rules are defined + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Should handle empty tool interactions gracefully + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void testExecuteWithSmallToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 10); + manager.initialize(config); + + // Add fewer interactions than the limit + addToolInteractionsToContext(5); + int originalSize = context.getToolInteractions().size(); + + manager.execute(context); + + // Tool interactions should remain unchanged + Assert.assertEquals(originalSize, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithLargeToolInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add more interactions than the limit + addToolInteractionsToContext(10); + + manager.execute(context); + + // Tool interactions should be truncated to the limit + Assert.assertEquals(5, context.getToolInteractions().size()); + + // Parameters should be updated with truncated interactions + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + // Should contain only the last 5 interactions + String[] interactions = interactionsParam.substring(2).split(", "); // Remove ", " prefix + Assert.assertEquals(5, interactions.length); + + // Should keep the most recent interactions (6-10) + for (int i = 0; i < interactions.length; i++) { + String expected = "Tool output " + (6 + i); + Assert.assertEquals(expected, interactions[i]); + } + + // Verify toolInteractions also contain the most recent ones + for (int i = 0; i < context.getToolInteractions().size(); i++) { + String expected = "Tool output " + (6 + i); + String actual = context.getToolInteractions().get(i); + Assert.assertEquals(expected, actual); + } + } + + @Test + public void testExecuteKeepsMostRecentInteractions() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Add interactions with identifiable content + addToolInteractionsToContext(7); + + manager.execute(context); + + // Should keep the last 3 interactions (5, 6, 7) + String interactionsParam = (String) 
context.getParameters().get("_interactions"); + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + Assert.assertEquals("Tool output 5", interactions[0]); + Assert.assertEquals("Tool output 6", interactions[1]); + Assert.assertEquals("Tool output 7", interactions[2]); + } + + @Test + public void testExecuteWithExactLimit() { + Map config = new HashMap<>(); + config.put("max_messages", 5); + manager.initialize(config); + + // Add exactly the limit number of interactions + addToolInteractionsToContext(5); + + manager.execute(context); + + // Parameters should not be updated since no truncation needed + Assert.assertNull(context.getParameters().get("_interactions")); + } + + @Test + public void testExecuteWithNullToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + context.setToolInteractions(null); + + // Should handle null tool interactions gracefully + manager.execute(context); + + // Should not throw exception + Assert.assertNull(context.getToolInteractions()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + config.put("max_messages", 1); // Set to 1 to force truncation + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should process all string interactions and set _interactions parameter + Assert.assertNotNull(context.getParameters().get("_interactions")); + Assert.assertEquals(1, context.getToolInteractions().size()); // Should keep only 1 + } + + @Test + public void testInvalidMaxMessagesConfig() { + Map config = new HashMap<>(); + config.put("max_messages", "invalid_number"); + + // Should handle invalid config gracefully and use default + manager.initialize(config); + + Assert.assertNotNull(manager); + } + + @Test + public void testExecuteWithNullParameters() { + Map config = new HashMap<>(); + config.put("max_messages", 3); + manager.initialize(config); + + // Set parameters to null + context.setParameters(null); + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should create new parameters map and update it + Assert.assertNotNull(context.getParameters()); + String interactionsParam = (String) context.getParameters().get("_interactions"); + Assert.assertNotNull(interactionsParam); + + String[] interactions = interactionsParam.substring(2).split(", "); + Assert.assertEquals(3, interactions.length); + } + + /** + * Helper method to add tool interactions to the context. 
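+     * Interactions are added as "Tool output 1" through "Tool output N" so the truncation
+     * tests can assert exactly which entries survive and in what order.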
+ */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java new file mode 100644 index 0000000000..6f769d2645 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManagerTest.java @@ -0,0 +1,326 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.contextmanager; + +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.contextmanager.ContextManagerContext; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.transport.client.Client; + +/** + * Unit tests for SummarizationManager. + */ +public class SummarizationManagerTest { + + @Mock + private Client client; + + private SummarizationManager manager; + private ContextManagerContext context; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + manager = new SummarizationManager(client); + context = ContextManagerContext.builder().toolInteractions(new ArrayList<>()).parameters(new HashMap<>()).build(); + } + + @Test + public void testGetType() { + Assert.assertEquals("SummarizationManager", manager.getType()); + } + + @Test + public void testInitializeWithDefaults() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + Assert.assertEquals(10, manager.preserveRecentMessages); + } + + @Test + public void testInitializeWithCustomConfig() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.5); + config.put("preserve_recent_messages", 5); + config.put("summarization_model_id", "test-model"); + config.put("summarization_system_prompt", "Custom prompt"); + + manager.initialize(config); + + Assert.assertEquals(0.5, manager.summaryRatio, 0.001); + Assert.assertEquals(5, manager.preserveRecentMessages); + Assert.assertEquals("test-model", manager.summarizationModelId); + Assert.assertEquals("Custom prompt", manager.summarizationSystemPrompt); + } + + @Test + public void testInitializeWithInvalidSummaryRatio() { + Map config = new HashMap<>(); + config.put("summary_ratio", 0.9); // Invalid - too high + + manager.initialize(config); + + // Should use default value + Assert.assertEquals(0.3, manager.summaryRatio, 0.001); + } + + @Test + public void testShouldActivateWithNoRules() { + Map config = new HashMap<>(); + manager.initialize(config); + + Assert.assertTrue(manager.shouldActivate(context)); + } + + @Test + public void testExecuteWithEmptyToolInteractions() { + Map config = new HashMap<>(); + manager.initialize(config); + + manager.execute(context); + + Assert.assertTrue(context.getToolInteractions().isEmpty()); + } + + @Test + public void 
testExecuteWithInsufficientMessages() { + Map config = new HashMap<>(); + config.put("preserve_recent_messages", 10); + manager.initialize(config); + + // Add only 5 interactions - not enough to summarize + addToolInteractionsToContext(5); + + manager.execute(context); + + // Should remain unchanged + Assert.assertEquals(5, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNoModelId() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(20); + + manager.execute(context); + + // Should remain unchanged due to missing model ID + Assert.assertEquals(20, context.getToolInteractions().size()); + } + + @Test + public void testExecuteWithNonStringOutputs() { + Map config = new HashMap<>(); + manager.initialize(config); + + // Add tool interactions as strings + context.getToolInteractions().add("123"); // Integer as string + context.getToolInteractions().add("String output"); // String output + + manager.execute(context); + + // Should handle gracefully - only 2 string interactions, not enough to summarize + Assert.assertEquals(2, context.getToolInteractions().size()); + } + + @Test + public void testProcessSummarizationResult() { + Map config = new HashMap<>(); + manager.initialize(config); + + addToolInteractionsToContext(10); + List remainingMessages = List.of("Message 6", "Message 7", "Message 8", "Message 9", "Message 10"); + + manager.processSummarizationResult(context, "Test summary", 5, remainingMessages, context.getToolInteractions()); + + // Should have 1 summary + 5 remaining = 6 total + Assert.assertEquals(6, context.getToolInteractions().size()); + + // First should be summary + String firstOutput = context.getToolInteractions().get(0); + Assert.assertTrue(firstOutput.contains("Test summary")); + } + + @Test + public void testExtractSummaryFromResponseWithLLMResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content"); + context.setParameters(parameters); + + // Create mock response with OpenAI-style structure + Map responseData = new HashMap<>(); + Map choice = new HashMap<>(); + Map message = new HashMap<>(); + message.put("content", "This is the extracted summary content"); + choice.put("message", message); + responseData.put("choices", List.of(choice)); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("This is the extracted summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithBedrockResponseFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with Bedrock-style LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text"); + context.setParameters(parameters); + + // Create mock response with Bedrock-style structure + Map responseData = new HashMap<>(); + Map output = new HashMap<>(); + Map message = new HashMap<>(); + Map content = new HashMap<>(); + content.put("text", 
"Bedrock extracted summary"); + message.put("content", List.of(content)); + output.put("message", message); + responseData.put("output", output); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Bedrock extracted summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithInvalidFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with invalid LLM_RESPONSE_FILTER path + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, "$.invalid.path"); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Fallback summary content"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + // Should fall back to default parsing + Assert.assertEquals("Fallback summary content", result); + } + + @Test + public void testExtractSummaryFromResponseWithoutFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Context without LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + context.setParameters(parameters); + + // Create mock response with simple structure + Map responseData = new HashMap<>(); + responseData.put("response", "Default parsed summary"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Default parsed summary", result); + } + + @Test + public void testExtractSummaryFromResponseWithEmptyFilter() throws Exception { + Map config = new HashMap<>(); + manager.initialize(config); + + // Set up context with empty LLM_RESPONSE_FILTER + Map parameters = new HashMap<>(); + parameters.put(LLM_RESPONSE_FILTER, ""); + context.setParameters(parameters); + + // Create mock response + Map responseData = new HashMap<>(); + responseData.put("response", "Empty filter fallback"); + + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); + + // Use reflection to access the private method + java.lang.reflect.Method extractMethod = SummarizationManager.class + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); + extractMethod.setAccessible(true); + + String result = (String) extractMethod.invoke(manager, mockResponse, context); + + Assert.assertEquals("Empty filter fallback", result); + } + + /** + * Helper method to create a mock MLTaskResponse with the given data. 
+ */ + private MLTaskResponse createMockMLTaskResponse(Map responseData) { + ModelTensor tensor = ModelTensor.builder().dataAsMap(responseData).build(); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + return MLTaskResponse.builder().output(output).build(); + } + + /** + * Helper method to add tool interactions to the context. + */ + private void addToolInteractionsToContext(int count) { + for (int i = 1; i <= count; i++) { + context.getToolInteractions().add("Tool output " + i); + } + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 01a01e1358..6addfac882 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -648,6 +648,16 @@ configurations.all { resolutionStrategy.force "jakarta.json:jakarta.json-api:2.1.3" resolutionStrategy.force "org.opensearch:opensearch:${opensearch_version}" resolutionStrategy.force "org.bouncycastle:bcprov-jdk18on:1.78.1" + // Force consistent Netty versions to resolve conflicts + resolutionStrategy.force 'io.netty:netty-codec-http:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec-http2:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-codec:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-common:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-buffer:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-handler:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-resolver:4.1.124.Final' + resolutionStrategy.force 'io.netty:netty-transport-native-unix-common:4.1.124.Final' resolutionStrategy.force 'io.projectreactor:reactor-core:3.7.0' resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib-jdk8:1.9.10" resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.9.23" diff --git a/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java new file mode 100644 index 0000000000..dc9ea439d8 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java @@ -0,0 +1,261 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import lombok.extern.log4j.Log4j2; + +/** + * Validator for ML Agent registration that performs advanced validation + * requiring service dependencies. + * This validator handles validation that cannot be performed in the request + * object itself, + * such as template existence checking. + */ +@Log4j2 +public class MLAgentRegistrationValidator { + + private final ContextManagementTemplateService contextManagementTemplateService; + + public MLAgentRegistrationValidator(ContextManagementTemplateService contextManagementTemplateService) { + this.contextManagementTemplateService = contextManagementTemplateService; + } + + /** + * Validates an ML agent for registration, performing all necessary validation checks. + * This is the main validation entry point that orchestrates all validation steps. 
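+     * Order of checks: inline context management configuration is validated first; if a
+     * context_management_name is set, template access is then verified through
+     * ContextManagementTemplateService before the listener is notified.
+     * Illustrative usage (callback body is a placeholder, not part of this class):
+     *   validator.validateAgentForRegistration(agent,
+     *       ActionListener.wrap(valid -> proceedWithRegistration(agent), listener::onFailure));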
+ * + * @param agent the ML agent to validate + * @param listener callback for validation result - onResponse(true) if valid, onFailure with exception if not + */ + public void validateAgentForRegistration(MLAgent agent, ActionListener listener) { + try { + log.debug("Starting agent registration validation for agent: {}", agent.getName()); + + // First, perform basic context management configuration validation + String configError = validateContextManagementConfiguration(agent); + if (configError != null) { + log.error("Agent registration validation failed - configuration error: {}", configError); + listener.onFailure(new IllegalArgumentException(configError)); + return; + } + + // If agent has a context management template reference, validate template access + if (agent.getContextManagementName() != null) { + validateContextManagementTemplateAccess(agent.getContextManagementName(), ActionListener.wrap(templateAccessValid -> { + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + }, templateAccessError -> { + log.error("Agent registration validation failed - template access error: {}", templateAccessError.getMessage()); + listener.onFailure(templateAccessError); + })); + } else { + // No template reference, validation is complete + log.debug("Agent registration validation completed successfully for agent: {}", agent.getName()); + listener.onResponse(true); + } + } catch (Exception e) { + log.error("Unexpected error during agent registration validation", e); + listener.onFailure(new IllegalArgumentException("Agent validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management template access (following connector access validation pattern). + * This method checks if the template exists and if the user has access to it. + * + * @param templateName the context management template name to validate + * @param listener callback for validation result - onResponse(true) if accessible, onFailure with exception if not + */ + public void validateContextManagementTemplateAccess(String templateName, ActionListener listener) { + try { + log.debug("Validating context management template access: {}", templateName); + + contextManagementTemplateService.getTemplate(templateName, ActionListener.wrap(template -> { + log.debug("Context management template access validation passed: {}", templateName); + listener.onResponse(true); + }, exception -> { + log.error("Context management template access validation failed: {}", templateName, exception); + if (exception instanceof MLResourceNotFoundException) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + templateName)); + } else { + listener + .onFailure( + new IllegalArgumentException("Failed to validate context management template: " + exception.getMessage()) + ); + } + })); + } catch (Exception e) { + log.error("Unexpected error during context management template access validation", e); + listener.onFailure(new IllegalArgumentException("Context management template validation failed: " + e.getMessage())); + } + } + + /** + * Validates context management configuration structure and requirements. + * This method performs comprehensive validation of context management settings. 
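+     * Rejects agents that set both context_management_name and an inline context_management
+     * block, and checks the template name format (non-empty, at most 256 characters, limited
+     * to letters, digits, underscores, hyphens, and dots).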
+ * + * @param agent the ML agent to validate + * @return validation error message if invalid, null if valid + */ + public String validateContextManagementConfiguration(MLAgent agent) { + // Check for conflicting configuration (both name and inline config specified) + if (agent.getContextManagementName() != null && agent.getContextManagement() != null) { + return "Cannot specify both context_management_name and context_management"; + } + + // Validate context management template name if specified + if (agent.getContextManagementName() != null) { + String templateNameError = validateContextManagementTemplateName(agent.getContextManagementName()); + if (templateNameError != null) { + return templateNameError; + } + } + + // Validate inline context management configuration if specified + if (agent.getContextManagement() != null) { + String inlineConfigError = validateInlineContextManagementConfiguration(agent.getContextManagement()); + if (inlineConfigError != null) { + return inlineConfigError; + } + } + + return null; // Valid + } + + /** + * Validates the context management template name format and basic requirements. + * + * @param templateName the template name to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementTemplateName(String templateName) { + if (templateName == null || templateName.trim().isEmpty()) { + return "Context management template name cannot be null or empty"; + } + + if (templateName.length() > 256) { + return "Context management template name cannot exceed 256 characters"; + } + + if (!templateName.matches("^[a-zA-Z0-9_\\-\\.]+$")) { + return "Context management template name can only contain letters, numbers, underscores, hyphens, and dots"; + } + + return null; // Valid + } + + /** + * Validates the inline context management configuration structure and content. + * + * @param contextManagement the context management configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateInlineContextManagementConfiguration( + org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement + ) { + // Use the built-in validation from ContextManagementTemplate + if (!contextManagement.isValid()) { + return "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations"; + } + + // Additional validation for specific requirements + if (contextManagement.getName() == null || contextManagement.getName().trim().isEmpty()) { + return "Context management configuration name cannot be null or empty"; + } + + if (contextManagement.getHooks() == null || contextManagement.getHooks().isEmpty()) { + return "Context management configuration must define at least one hook"; + } + + // Validate hook names and configurations + return validateContextManagementHooks(contextManagement.getHooks()); + } + + /** + * Validates context management hooks configuration. 
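+     * Accepted hook names are PRE_TOOL, POST_TOOL, PRE_LLM, POST_LLM, PRE_EXECUTION, and
+     * POST_EXECUTION; each hook must define at least one valid context manager configuration,
+     * and unknown manager types are tolerated for extensibility.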
+ * + * @param hooks the hooks configuration to validate + * @return validation error message if invalid, null if valid + */ + private String validateContextManagementHooks( + java.util.Map> hooks + ) { + // Define valid hook names + java.util.Set validHookNames = java.util.Set + .of("PRE_TOOL", "POST_TOOL", "PRE_LLM", "POST_LLM", "PRE_EXECUTION", "POST_EXECUTION"); + + for (java.util.Map.Entry> entry : hooks + .entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + // Validate hook name + if (!validHookNames.contains(hookName)) { + return "Invalid hook name: " + hookName + ". Valid hook names are: " + validHookNames; + } + + // Validate hook configurations + if (configs == null || configs.isEmpty()) { + return "Hook " + hookName + " must have at least one context manager configuration"; + } + + for (int i = 0; i < configs.size(); i++) { + org.opensearch.ml.common.contextmanager.ContextManagerConfig config = configs.get(i); + if (!config.isValid()) { + return "Invalid context manager configuration at index " + + i + + " in hook " + + hookName + + ": type cannot be null or empty"; + } + + // Validate context manager type + if (config.getType() != null) { + String typeError = validateContextManagerType(config.getType(), hookName, i); + if (typeError != null) { + return typeError; + } + } + } + } + + return null; // Valid + } + + /** + * Validates context manager type for known types. + * + * @param type the context manager type to validate + * @param hookName the hook name for error reporting + * @param index the configuration index for error reporting + * @return validation error message if invalid, null if valid + */ + private String validateContextManagerType(String type, String hookName, int index) { + // Define known context manager types + java.util.Set knownTypes = java.util.Set + .of("ToolsOutputTruncateManager", "SummarizationManager", "MemoryManager", "ConversationManager"); + + // For now, we'll allow unknown types to provide flexibility for future context + // manager types + // This provides extensibility while still validating known ones + if (!knownTypes.contains(type)) { + log + .debug( + "Unknown context manager type '{}' in hook '{}' at index {}. 
This may be a custom or future type.", + type, + hookName, + index + ); + } + + return null; // Valid - we allow unknown types for extensibility + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index f4a4b8ff0e..883ab6771a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -26,6 +26,8 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.agent.MLAgentRegistrationValidator; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.AgentModelService; import org.opensearch.ml.common.agent.LLMSpec; @@ -62,6 +64,7 @@ public class TransportRegisterAgentAction extends HandledTransportAction listener) { + // Validate context management configuration (following connector pattern) + if (agent.hasContextManagementTemplate()) { + // Validate context management template access (similar to connector access validation) + String templateName = agent.getContextManagementTemplateName(); + agentRegistrationValidator.validateContextManagementTemplateAccess(templateName, ActionListener.wrap(hasAccess -> { + if (Boolean.TRUE.equals(hasAccess)) { + continueAgentRegistration(agent, listener); + } else { + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the context management template provided, template name: " + templateName + ) + ); + } + }, e -> { + log.error("You don't have permission to use the context management template provided, template name: {}", templateName, e); + listener.onFailure(e); + })); + } else if (agent.getInlineContextManagement() != null) { + // Validate inline context management configuration only if it exists (similar to inline connector validation) + validateInlineContextManagement(agent); + continueAgentRegistration(agent, listener); + } else { + // No context management configuration - that's fine, continue with registration + continueAgentRegistration(agent, listener); + } + } + + private void validateInlineContextManagement(MLAgent agent) { + if (agent.getInlineContextManagement() == null) { + log + .error( + "You must provide context management content when creating an agent without providing context management template name!" + ); + throw new IllegalArgumentException( + "You must provide context management content when creating an agent without context management template name!" + ); + } + + // Validate inline context management configuration structure + if (!agent.getInlineContextManagement().isValid()) { + log + .error( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + throw new IllegalArgumentException( + "Invalid context management configuration: configuration must have a name and at least one hook with valid context manager configurations" + ); + } + } + + private void continueAgentRegistration(MLAgent agent, ActionListener listener) { String mcpConnectorConfigJSON = (agent.getParameters() != null) ? 
agent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { // MCP connector provided as tools but MCP feature is disabled, so abort. diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java new file mode 100644 index 0000000000..8ecc0a0e2f --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtils.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Utility class for managing context management template indices. + * Handles index creation, mapping definition, and settings configuration. + */ +@Log4j2 +public class ContextManagementIndexUtils { + + public static final String CONTEXT_MANAGEMENT_TEMPLATES_INDEX = "ml_context_management_templates"; + + private final Client client; + private final ClusterService clusterService; + + public ContextManagementIndexUtils(Client client, ClusterService clusterService) { + this.client = client; + this.clusterService = clusterService; + } + + /** + * Check if the context management templates index exists + * @return true if the index exists, false otherwise + */ + public boolean doesIndexExist() { + return clusterService.state().metadata().hasIndex(CONTEXT_MANAGEMENT_TEMPLATES_INDEX); + } + + /** + * Create the context management templates index if it doesn't exist + * @param listener ActionListener for the response + */ + public void createIndexIfNotExists(ActionListener listener) { + if (doesIndexExist()) { + log.debug("Context management templates index already exists"); + listener.onResponse(true); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + CreateIndexRequest createIndexRequest = new CreateIndexRequest(CONTEXT_MANAGEMENT_TEMPLATES_INDEX).settings(getIndexSettings()); + + client.admin().indices().create(createIndexRequest, ActionListener.wrap(createIndexResponse -> { + log.info("Successfully created context management templates index"); + wrappedListener.onResponse(true); + }, exception -> { + if (exception instanceof org.opensearch.ResourceAlreadyExistsException) { + log.debug("Context management templates index already exists"); + wrappedListener.onResponse(true); + } else { + log.error("Failed to create context management templates index", exception); + wrappedListener.onFailure(exception); + } + })); + } catch (Exception e) { + log.error("Error creating context management templates index", e); + listener.onFailure(e); + } + } + + /** + * Get the index settings for context management templates + * @return Settings for the index + */ + private Settings getIndexSettings() { + return Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + 
.put("index.auto_expand_replicas", "0-1") + .build(); + } + + /** + * Get the index name for context management templates + * @return The index name + */ + public static String getIndexName() { + return CONTEXT_MANAGEMENT_TEMPLATES_INDEX; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java new file mode 100644 index 0000000000..d754375688 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateService.java @@ -0,0 +1,316 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; + +import java.time.Instant; +import java.util.List; + +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Service for managing context management templates in OpenSearch. + * Provides CRUD operations for storing and retrieving context management configurations. 
+ */ +@Log4j2 +public class ContextManagementTemplateService { + + private final MLIndicesHandler mlIndicesHandler; + private final Client client; + private final ClusterService clusterService; + private final ContextManagementIndexUtils indexUtils; + + @Inject + public ContextManagementTemplateService(MLIndicesHandler mlIndicesHandler, Client client, ClusterService clusterService) { + this.mlIndicesHandler = mlIndicesHandler; + this.client = client; + this.clusterService = clusterService; + this.indexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + /** + * Save a context management template to OpenSearch + * @param templateName The name of the template + * @param template The template to save + * @param listener ActionListener for the response + */ + public void saveTemplate(String templateName, ContextManagementTemplate template, ActionListener listener) { + try { + // Validate template + if (!template.isValid()) { + listener.onFailure(new IllegalArgumentException("Invalid context management template")); + return; + } + + User user = RestActionUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + // Set timestamps + Instant now = Instant.now(); + if (template.getCreatedTime() == null) { + template.setCreatedTime(now); + } + template.setLastModified(now); + + // Set created by if not already set + if (template.getCreatedBy() == null && user != null) { + template.setCreatedBy(user.getName()); + } + + // Ensure index exists first + indexUtils.createIndexIfNotExists(ActionListener.wrap(indexCreated -> { + // Check if template with same name already exists + validateUniqueTemplateName(template.getName(), ActionListener.wrap(exists -> { + if (exists) { + wrappedListener + .onFailure( + new IllegalArgumentException( + "A context management template with name '" + template.getName() + "' already exists" + ) + ); + return; + } + + // Create the index request with proper JSON serialization + IndexRequest indexRequest = new IndexRequest(ContextManagementIndexUtils.getIndexName()) + .id(template.getName()) + .source(template.toXContent(jsonXContent.contentBuilder(), ToXContentObject.EMPTY_PARAMS)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // Execute the index operation + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("Context management template saved successfully: {}", template.getName()); + wrappedListener.onResponse(true); + }, exception -> { + log.error("Failed to save context management template: {}", template.getName(), exception); + wrappedListener.onFailure(exception); + })); + }, wrappedListener::onFailure)); + }, wrappedListener::onFailure)); + } + } catch (Exception e) { + log.error("Error saving context management template", e); + listener.onFailure(e); + } + } + + /** + * Get a context management template by name + * @param templateName The name of the template to retrieve + * @param listener ActionListener for the response + */ + public void getTemplate(String templateName, ActionListener listener) { + try { + if (Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, 
context::restore); + + GetRequest getRequest = new GetRequest(ContextManagementIndexUtils.getIndexName(), templateName); + + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + return; + } + + try { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsBytesRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + wrappedListener.onResponse(template); + } catch (Exception e) { + log.error("Failed to parse context management template: {}", templateName, e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener + .onFailure(new MLResourceNotFoundException("Context management template not found: " + templateName)); + } else { + log.error("Failed to get context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error getting context management template", e); + listener.onFailure(e); + } + } + + /** + * List all context management templates + * @param listener ActionListener for the response + */ + public void listTemplates(ActionListener> listener) { + listTemplates(0, 1000, listener); + } + + /** + * List context management templates with pagination + * @param from Starting index for pagination + * @param size Number of templates to return + * @param listener ActionListener for the response + */ + public void listTemplates(int from, int size, ActionListener> listener) { + try { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener> wrappedListener = ActionListener.runBefore(listener, context::restore); + + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new MatchAllQueryBuilder()).from(from).size(size); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + try { + List templates = new java.util.ArrayList<>(); + for (SearchHit hit : searchResponse.getHits().getHits()) { + XContentParser parser = createXContentParserFromRegistry( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceRef() + ); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + templates.add(template); + } + wrappedListener.onResponse(templates); + } catch (Exception e) { + log.error("Failed to parse context management templates", e); + wrappedListener.onFailure(e); + } + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Return empty list if index doesn't exist + wrappedListener.onResponse(new java.util.ArrayList<>()); + } else { + log.error("Failed to list context management templates", exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error listing context management templates", e); + listener.onFailure(e); + } + } + + /** + * Delete a context management template by name + * @param templateName The name of the template to delete + * @param listener ActionListener for the response + */ + public void deleteTemplate(String templateName, ActionListener listener) { + try { + if 
(Strings.isNullOrEmpty(templateName)) { + listener.onFailure(new IllegalArgumentException("Template name cannot be null or empty")); + return; + } + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); + + DeleteRequest deleteRequest = new DeleteRequest(ContextManagementIndexUtils.getIndexName(), templateName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + client.delete(deleteRequest, ActionListener.wrap(deleteResponse -> { + boolean deleted = deleteResponse.getResult() == DeleteResponse.Result.DELETED; + if (deleted) { + log.info("Context management template deleted successfully: {}", templateName); + } else { + log.warn("Context management template not found for deletion: {}", templateName); + } + wrappedListener.onResponse(deleted); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + wrappedListener.onResponse(false); + } else { + log.error("Failed to delete context management template: {}", templateName, exception); + wrappedListener.onFailure(exception); + } + })); + } + } catch (Exception e) { + log.error("Error deleting context management template", e); + listener.onFailure(e); + } + } + + /** + * Validate that a template name is unique + * @param templateName The template name to check + * @param listener ActionListener for the response (true if exists, false if unique) + */ + private void validateUniqueTemplateName(String templateName, ActionListener listener) { + try { + SearchRequest searchRequest = new SearchRequest(ContextManagementIndexUtils.getIndexName()); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(new TermQueryBuilder("_id", templateName)).size(1); + searchRequest.source(searchSourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + boolean exists = searchResponse.getHits().getTotalHits() != null && searchResponse.getHits().getTotalHits().value() > 0; + listener.onResponse(exists); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + // Index doesn't exist, so template name is unique + listener.onResponse(false); + } else { + listener.onFailure(exception); + } + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Create XContentParser from registry - utility method + */ + private XContentParser createXContentParserFromRegistry( + NamedXContentRegistry xContentRegistry, + LoggingDeprecationHandler deprecationHandler, + org.opensearch.core.common.bytes.BytesReference bytesReference + ) throws java.io.IOException { + return MediaTypeRegistry.JSON.xContent().createParser(xContentRegistry, deprecationHandler, bytesReference.streamInput()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java new file mode 100644 index 0000000000..86b26b4d6b --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactory.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import java.util.Map; + +import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import 
org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +/** + * Factory for creating context manager instances from configuration. + * This factory creates the appropriate context manager based on the type + * specified in the configuration and initializes it with the provided settings. + */ +@Log4j2 +public class ContextManagerFactory { + + private final ActivationRuleFactory activationRuleFactory; + private final Client client; + + @Inject + public ContextManagerFactory(ActivationRuleFactory activationRuleFactory, Client client) { + this.activationRuleFactory = activationRuleFactory; + this.client = client; + } + + /** + * Create a context manager instance from configuration + * @param config The context manager configuration + * @return The created context manager instance + * @throws IllegalArgumentException if the manager type is not supported + */ + public ContextManager createContextManager(ContextManagerConfig config) { + if (config == null || config.getType() == null) { + throw new IllegalArgumentException("Context manager configuration and type cannot be null"); + } + + String type = config.getType(); + Map managerConfig = config.getConfig(); + Map activationConfig = config.getActivation(); + + log.debug("Creating context manager of type: {}", type); + + ContextManager manager; + switch (type) { + case "ToolsOutputTruncateManager": + manager = createToolsOutputTruncateManager(managerConfig); + break; + case "SlidingWindowManager": + manager = createSlidingWindowManager(managerConfig); + break; + case "SummarizationManager": + manager = createSummarizationManager(managerConfig); + break; + default: + throw new IllegalArgumentException("Unsupported context manager type: " + type); + } + + // Initialize the manager with configuration + try { + // Merge activation and manager config for initialization + Map fullConfig = new java.util.HashMap<>(); + if (managerConfig != null) { + fullConfig.putAll(managerConfig); + } + if (activationConfig != null) { + fullConfig.put("activation", activationConfig); + } + + manager.initialize(fullConfig); + log.debug("Successfully created and initialized context manager: {}", type); + return manager; + } catch (Exception e) { + log.error("Failed to initialize context manager of type: {}", type, e); + throw new RuntimeException("Failed to initialize context manager: " + type, e); + } + } + + /** + * Create a ToolsOutputTruncateManager instance + */ + private ContextManager createToolsOutputTruncateManager(Map config) { + return new ToolsOutputTruncateManager(); + } + + /** + * Create a SlidingWindowManager instance + */ + private ContextManager createSlidingWindowManager(Map config) { + return new SlidingWindowManager(); + } + + /** + * Create a SummarizationManager instance + */ + private ContextManager createSummarizationManager(Map config) { + return new SummarizationManager(client); + } + + // Add more factory methods for other context manager types as they are implemented + + // private ContextManager createSummarizingManager(Map config) { + // return new SummarizingManager(); + // } + + // private ContextManager createSystemPromptAugmentationManager(Map config) { + // return new 
SystemPromptAugmentationManager(); + // } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..d5377d1bf1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public CreateContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLCreateContextManagementTemplateAction.NAME, transportService, actionFilters, MLCreateContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLCreateContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Creating context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.saveTemplate(request.getTemplateName(), request.getTemplate(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully created context management template: {}", request.getTemplateName()); + listener.onResponse(new MLCreateContextManagementTemplateResponse(request.getTemplateName(), "created")); + } else { + log.error("Failed to create context management template: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Failed to create context management template")); + } + }, exception -> { + log.error("Error creating context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error creating context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java new file mode 100644 
index 0000000000..6c025bc927 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public DeleteContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLDeleteContextManagementTemplateAction.NAME, transportService, actionFilters, MLDeleteContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLDeleteContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.info("Deleting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.deleteTemplate(request.getTemplateName(), ActionListener.wrap(success -> { + if (success) { + log.info("Successfully deleted context management template: {}", request.getTemplateName()); + listener.onResponse(new MLDeleteContextManagementTemplateResponse(request.getTemplateName(), "deleted")); + } else { + log.warn("Context management template not found for deletion: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error deleting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error deleting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java new file mode 100644 index 0000000000..011b6852c0 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportAction.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetContextManagementTemplateTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public GetContextManagementTemplateTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLGetContextManagementTemplateAction.NAME, transportService, actionFilters, MLGetContextManagementTemplateRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLGetContextManagementTemplateRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Getting context management template: {}", request.getTemplateName()); + + contextManagementTemplateService.getTemplate(request.getTemplateName(), ActionListener.wrap(template -> { + if (template != null) { + log.debug("Successfully retrieved context management template: {}", request.getTemplateName()); + listener.onResponse(new MLGetContextManagementTemplateResponse(template)); + } else { + log.warn("Context management template not found: {}", request.getTemplateName()); + listener.onFailure(new RuntimeException("Context management template not found: " + request.getTemplateName())); + } + }, exception -> { + log.error("Error getting context management template: {}", request.getTemplateName(), exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error getting context management template: {}", request.getTemplateName(), e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java new file mode 100644 index 0000000000..7667ac6cc6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportAction.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import 
org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ListContextManagementTemplatesTransportAction extends + HandledTransportAction { + + private final Client client; + private final ContextManagementTemplateService contextManagementTemplateService; + + @Inject + public ListContextManagementTemplatesTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ContextManagementTemplateService contextManagementTemplateService + ) { + super(MLListContextManagementTemplatesAction.NAME, transportService, actionFilters, MLListContextManagementTemplatesRequest::new); + this.client = client; + this.contextManagementTemplateService = contextManagementTemplateService; + } + + @Override + protected void doExecute( + Task task, + MLListContextManagementTemplatesRequest request, + ActionListener listener + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + log.debug("Listing context management templates from: {} size: {}", request.getFrom(), request.getSize()); + + contextManagementTemplateService.listTemplates(request.getFrom(), request.getSize(), ActionListener.wrap(templates -> { + log.debug("Successfully retrieved {} context management templates", templates.size()); + // For now, return the size as total. In a real implementation, you'd get the actual total count + listener.onResponse(new MLListContextManagementTemplatesResponse(templates, templates.size())); + }, exception -> { + log.error("Error listing context management templates", exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Unexpected error listing context management templates", e); + listener.onFailure(e); + } + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 55cdfc67b5..a9e5000def 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -89,6 +89,12 @@ import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; +import org.opensearch.ml.action.contextmanagement.CreateContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.DeleteContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.GetContextManagementTemplateTransportAction; +import org.opensearch.ml.action.contextmanagement.ListContextManagementTemplatesTransportAction; import org.opensearch.ml.action.controller.CreateControllerTransportAction; import org.opensearch.ml.action.controller.DeleteControllerTransportAction; import org.opensearch.ml.action.controller.DeployControllerTransportAction; 
@@ -191,6 +197,10 @@ import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; import org.opensearch.ml.common.transport.controller.MLControllerGetAction; import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; @@ -323,15 +333,16 @@ import org.opensearch.ml.processor.MLInferenceIngestProcessor; import org.opensearch.ml.processor.MLInferenceSearchRequestProcessor; import org.opensearch.ml.processor.MLInferenceSearchResponseProcessor; -import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLAddMemoriesAction; import org.opensearch.ml.rest.RestMLCancelBatchJobAction; import org.opensearch.ml.rest.RestMLCreateConnectorAction; +import org.opensearch.ml.rest.RestMLCreateContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLCreateMemoryContainerAction; import org.opensearch.ml.rest.RestMLCreateSessionAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLDeleteControllerAction; import org.opensearch.ml.rest.RestMLDeleteMemoriesByQueryAction; import org.opensearch.ml.rest.RestMLDeleteMemoryAction; @@ -345,6 +356,7 @@ import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConfigAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; +import org.opensearch.ml.rest.RestMLGetContextManagementTemplateAction; import org.opensearch.ml.rest.RestMLGetControllerAction; import org.opensearch.ml.rest.RestMLGetIndexInsightAction; import org.opensearch.ml.rest.RestMLGetIndexInsightConfigAction; @@ -354,6 +366,7 @@ import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLGetToolAction; +import org.opensearch.ml.rest.RestMLListContextManagementTemplatesAction; import org.opensearch.ml.rest.RestMLListToolsAction; import org.opensearch.ml.rest.RestMLPredictionAction; import org.opensearch.ml.rest.RestMLPredictionStreamAction; @@ -451,6 +464,7 @@ import org.opensearch.watcher.ResourceWatcherService; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import lombok.SneakyThrows; @@ -621,7 +635,11 @@ public MachineLearningPlugin() {} new ActionHandler<>(MLMcpToolsListAction.INSTANCE, TransportMcpToolsListAction.class), new ActionHandler<>(MLMcpToolsUpdateAction.INSTANCE, TransportMcpToolsUpdateAction.class), new ActionHandler<>(MLMcpToolsUpdateOnNodesAction.INSTANCE, TransportMcpToolsUpdateOnNodesAction.class), - new ActionHandler<>(MLMcpServerAction.INSTANCE, TransportMcpServerAction.class) + new ActionHandler<>(MLMcpServerAction.INSTANCE, 
TransportMcpServerAction.class), + new ActionHandler<>(MLCreateContextManagementTemplateAction.INSTANCE, CreateContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLGetContextManagementTemplateAction.INSTANCE, GetContextManagementTemplateTransportAction.class), + new ActionHandler<>(MLListContextManagementTemplatesAction.INSTANCE, ListContextManagementTemplatesTransportAction.class), + new ActionHandler<>(MLDeleteContextManagementTemplateAction.INSTANCE, DeleteContextManagementTemplateTransportAction.class) ); } @@ -789,6 +807,17 @@ public Collection createComponents( nodeHelper, mlEngine ); + // Create context management services + ContextManagementTemplateService contextManagementTemplateService = new ContextManagementTemplateService( + mlIndicesHandler, + client, + clusterService + ); + ContextManagerFactory contextManagerFactory = new ContextManagerFactory( + new org.opensearch.ml.common.contextmanager.ActivationRuleFactory(), + client + ); + mlExecuteTaskRunner = new MLExecuteTaskRunner( threadPool, clusterService, @@ -799,7 +828,9 @@ public Collection createComponents( mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ); // Register thread-safe ML objects here. @@ -1091,6 +1122,15 @@ public List getRestHandlers( RestMLMcpToolsRemoveAction restMLRemoveMcpToolsAction = new RestMLMcpToolsRemoveAction(clusterService, mlFeatureEnabledSetting); RestMLMcpToolsListAction restMLListMcpToolsAction = new RestMLMcpToolsListAction(mlFeatureEnabledSetting); RestMLMcpToolsUpdateAction restMLMcpToolsUpdateAction = new RestMLMcpToolsUpdateAction(clusterService, mlFeatureEnabledSetting); + RestMLCreateContextManagementTemplateAction restMLCreateContextManagementTemplateAction = + new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + RestMLGetContextManagementTemplateAction restMLGetContextManagementTemplateAction = new RestMLGetContextManagementTemplateAction( + mlFeatureEnabledSetting + ); + RestMLListContextManagementTemplatesAction restMLListContextManagementTemplatesAction = + new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + RestMLDeleteContextManagementTemplateAction restMLDeleteContextManagementTemplateAction = + new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); return ImmutableList .of( restMLStatsAction, @@ -1167,7 +1207,11 @@ public List getRestHandlers( restMLListMcpToolsAction, restMLMcpToolsUpdateAction, restMLPutIndexInsightConfigAction, - restMLGetIndexInsightConfigAction + restMLGetIndexInsightConfigAction, + restMLCreateContextManagementTemplateAction, + restMLGetContextManagementTemplateAction, + restMLListContextManagementTemplatesAction, + restMLDeleteContextManagementTemplateAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java index bb55f3fb21..a2cca4bbaa 100644 --- a/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java +++ b/plugin/src/main/java/org/opensearch/ml/resources/MLResourceSharingExtension.java @@ -19,7 +19,17 @@ public class MLResourceSharingExtension implements ResourceSharingExtension { @Override public Set getResourceProviders() { - return Set.of(new ResourceProvider(ML_MODEL_GROUP_RESOURCE_TYPE, ML_MODEL_GROUP_INDEX)); + return Set.of(new ResourceProvider() { + @Override + public String resourceType() { + return 
ML_MODEL_GROUP_RESOURCE_TYPE; + } + + @Override + public String resourceIndexName() { + return ML_MODEL_GROUP_INDEX; + } + }); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java new file mode 100644 index 0000000000..34387ccb81 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateAction.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLCreateContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_create_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLCreateContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_CREATE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLCreateContextManagementTemplateRequest createRequest = getRequest(request); + return channel -> client + .execute(MLCreateContextManagementTemplateAction.INSTANCE, createRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLCreateContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLCreateContextManagementTemplateRequest + */ + @VisibleForTesting + MLCreateContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, 
parser.nextToken(), parser); + ContextManagementTemplate template = ContextManagementTemplate.parse(parser); + + // Set the template name from URL parameter + template = template.toBuilder().name(templateName).build(); + + return new MLCreateContextManagementTemplateRequest(templateName, template); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java new file mode 100644 index 0000000000..1dbde7f216 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateAction.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLDeleteContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_delete_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLDeleteContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_DELETE_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLDeleteContextManagementTemplateRequest deleteRequest = getRequest(request); + return channel -> client + .execute(MLDeleteContextManagementTemplateAction.INSTANCE, deleteRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLDeleteContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLDeleteContextManagementTemplateRequest + */ + @VisibleForTesting + MLDeleteContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLDeleteContextManagementTemplateRequest(templateName); + } +} diff --git 
a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java new file mode 100644 index 0000000000..4089d0aac1 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateAction.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLGetContextManagementTemplateAction extends BaseRestHandler { + private static final String ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION = "ml_get_context_management_template_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLGetContextManagementTemplateAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_GET_CONTEXT_MANAGEMENT_TEMPLATE_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/context_management/{%s}", ML_BASE_URI, PARAMETER_TEMPLATE_NAME) + ) + ); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLGetContextManagementTemplateRequest getRequest = getRequest(request); + return channel -> client.execute(MLGetContextManagementTemplateAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLGetContextManagementTemplateRequest from a RestRequest + * + * @param request RestRequest + * @return MLGetContextManagementTemplateRequest + */ + @VisibleForTesting + MLGetContextManagementTemplateRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + String templateName = request.param(PARAMETER_TEMPLATE_NAME); + if (templateName == null || templateName.trim().isEmpty()) { + throw new IllegalArgumentException("Template name is required"); + } + + return new MLGetContextManagementTemplateRequest(templateName); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java new file mode 100644 index 0000000000..d5020bd5c3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesAction.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + 
*/ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.transport.client.node.NodeClient; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLListContextManagementTemplatesAction extends BaseRestHandler { + private static final String ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION = "ml_list_context_management_templates_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + + /** + * Constructor + */ + public RestMLListContextManagementTemplatesAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_LIST_CONTEXT_MANAGEMENT_TEMPLATES_ACTION; + } + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/context_management", ML_BASE_URI))); + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLListContextManagementTemplatesRequest listRequest = getRequest(request); + return channel -> client + .execute(MLListContextManagementTemplatesAction.INSTANCE, listRequest, new RestToXContentListener<>(channel)); + } + + /** + * Creates a MLListContextManagementTemplatesRequest from a RestRequest + * + * @param request RestRequest + * @return MLListContextManagementTemplatesRequest + */ + @VisibleForTesting + MLListContextManagementTemplatesRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException("Agent framework is disabled"); + } + + int from = request.paramAsInt("from", 0); + int size = request.paramAsInt("size", 10); + + return new MLListContextManagementTemplatesRequest(from, size); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 73281e0333..436c659b9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -15,10 +15,18 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import 
org.opensearch.ml.common.contextmanager.ContextManagerHookProvider; +import org.opensearch.ml.common.hooks.HookRegistry; import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; @@ -50,6 +58,8 @@ public class MLExecuteTaskRunner extends MLTaskRunner wrappedListener = ActionListener.runBefore(listener, ) Input input = request.getInput(); FunctionName functionName = request.getFunctionName(); + + // Handle agent execution with context management + if (FunctionName.AGENT.equals(functionName) && input instanceof AgentMLInput) { + AgentMLInput agentInput = (AgentMLInput) input; + String contextManagementName = getEffectiveContextManagementName(agentInput); + + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + // Execute agent with context management + executeAgentWithContextManagement(request, contextManagementName, channel, listener); + return; + } + } + if (FunctionName.METRICS_CORRELATION.equals(functionName)) { if (!isPythonModelEnabled) { Exception exception = new IllegalArgumentException("This algorithm is not enabled from settings"); @@ -163,10 +189,17 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener { - MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); - listener.onResponse(response); - }, e -> { listener.onFailure(e); }), channel); + + // Default execution for all functions (including agents without context management) + try { + mlEngine.execute(input, ActionListener.wrap(output -> { + MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); + listener.onResponse(response); + }, e -> { listener.onFailure(e); }), channel); + } catch (Exception e) { + log.error("Failed to execute ML function", e); + listener.onFailure(e); + } } catch (Exception e) { mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) @@ -178,4 +211,218 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener listener + ) { + log.debug("Executing agent with context management: {}", contextManagementName); + + // Lookup context management template + contextManagementTemplateService.getTemplate(contextManagementName, ActionListener.wrap(template -> { + if (template == null) { + listener.onFailure(new IllegalArgumentException("Context management template not found: " + contextManagementName)); + return; + } + + try { + // Create context managers from template + java.util.List contextManagers = createContextManagers(template); + + // Create HookRegistry with context managers + HookRegistry hookRegistry = createHookRegistry(contextManagers, template); + + // Set hook registry in agent input + AgentMLInput agentInput = (AgentMLInput) request.getInput(); + agentInput.setHookRegistry(hookRegistry); + + log + .info( + "Executing agent with context management template: {} using {} context managers", + contextManagementName, + contextManagers.size() + ); + + // Execute agent with hook registry + try { + mlEngine.execute(request.getInput(), ActionListener.wrap(output -> { + log.info("Agent execution completed successfully with context management"); + MLExecuteTaskResponse response = new MLExecuteTaskResponse(request.getFunctionName(), output); + 
listener.onResponse(response); + }, error -> { + log.error("Agent execution failed with context management", error); + listener.onFailure(error); + }), channel); + } catch (Exception e) { + log.error("Failed to execute agent with context management", e); + listener.onFailure(e); + } + + } catch (Exception e) { + log.error("Failed to create context managers from template: {}", contextManagementName, e); + listener.onFailure(e); + } + }, error -> { + log.error("Failed to retrieve context management template: {}", contextManagementName, error); + listener.onFailure(error); + })); + } + + /** + * Gets the effective context management name for an agent. + * Priority: 1) Runtime parameter from execution request, 2) Agent's stored configuration, 3) Runtime parameters set by MLAgentExecutor + * This follows the same pattern as MCP connectors. + * + * @param agentInput the agent ML input + * @return the effective context management name, or null if none configured + */ + private String getEffectiveContextManagementName(AgentMLInput agentInput) { + // Priority 1: Runtime parameter from execution request (user override) + String runtimeContextManagementName = agentInput.getContextManagementName(); + if (runtimeContextManagementName != null && !runtimeContextManagementName.trim().isEmpty()) { + log.debug("Using runtime context management name: {}", runtimeContextManagementName); + return runtimeContextManagementName; + } + + // Priority 2: Check agent's stored configuration directly + String agentId = agentInput.getAgentId(); + if (agentId != null) { + try { + // Use a blocking call to get the agent synchronously + // This is acceptable here since we're in the task execution path + java.util.concurrent.CompletableFuture future = new java.util.concurrent.CompletableFuture<>(); + + try ( + org.opensearch.common.util.concurrent.ThreadContext.StoredContext context = client + .threadPool() + .getThreadContext() + .stashContext() + ) { + client + .get( + new org.opensearch.action.get.GetRequest(org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX, agentId), + org.opensearch.core.action.ActionListener.runBefore(org.opensearch.core.action.ActionListener.wrap(response -> { + if (response.isExists()) { + try { + org.opensearch.core.xcontent.XContentParser parser = + org.opensearch.common.xcontent.json.JsonXContent.jsonXContent + .createParser( + null, + org.opensearch.common.xcontent.LoggingDeprecationHandler.INSTANCE, + response.getSourceAsString() + ); + org.opensearch.core.xcontent.XContentParserUtils + .ensureExpectedToken( + org.opensearch.core.xcontent.XContentParser.Token.START_OBJECT, + parser.nextToken(), + parser + ); + org.opensearch.ml.common.agent.MLAgent mlAgent = org.opensearch.ml.common.agent.MLAgent + .parse(parser); + + if (mlAgent.hasContextManagementTemplate()) { + String templateName = mlAgent.getContextManagementTemplateName(); + future.complete(templateName); + } else { + future.complete(null); + } + } catch (Exception e) { + future.completeExceptionally(e); + } + } else { + future.complete(null); // Agent not found + } + }, future::completeExceptionally), context::restore) + ); + } + + // Wait for the result with a timeout + String contextManagementName = future.get(5, java.util.concurrent.TimeUnit.SECONDS); + if (contextManagementName != null && !contextManagementName.trim().isEmpty()) { + return contextManagementName; + } + } catch (Exception e) { + // Continue to fallback methods + } + } + + // Priority 3: Agent's runtime parameters (set by MLAgentExecutor in input parameters) + if 
(agentInput.getInputDataset() instanceof org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) { + org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet dataset = + (org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet) agentInput.getInputDataset(); + + // Check if context management has already been processed by MLAgentExecutor (for inline templates) + String contextManagementProcessed = dataset.getParameters().get("context_management_processed"); + if ("true".equals(contextManagementProcessed)) { + log.debug("Context management already processed by MLAgentExecutor, skipping MLExecuteTaskRunner processing"); + return null; // Skip processing in MLExecuteTaskRunner + } + + // Handle template references (not processed by MLAgentExecutor) + String agentContextManagementName = dataset.getParameters().get("context_management"); + if (agentContextManagementName != null && !agentContextManagementName.trim().isEmpty()) { + return agentContextManagementName; + } + } + + return null; + } + + /** + * Create context managers from template configuration + */ + private java.util.List createContextManagers(ContextManagementTemplate template) { + java.util.List contextManagers = new java.util.ArrayList<>(); + + // Iterate through all hooks in the template + for (java.util.Map.Entry> entry : template.getHooks().entrySet()) { + String hookName = entry.getKey(); + java.util.List configs = entry.getValue(); + + for (ContextManagerConfig config : configs) { + try { + ContextManager manager = contextManagerFactory.createContextManager(config); + if (manager != null) { + contextManagers.add(manager); + log.debug("Created context manager: {} for hook: {}", config.getType(), hookName); + } else { + log.warn("Failed to create context manager of type: {}", config.getType()); + } + } catch (Exception e) { + log.error("Error creating context manager of type: {}", config.getType(), e); + // Continue with other managers + } + } + } + + log.info("Created {} context managers from template: {}", contextManagers.size(), template.getName()); + return contextManagers; + } + + /** + * Create HookRegistry with context managers + */ + private HookRegistry createHookRegistry(java.util.List contextManagers, ContextManagementTemplate template) { + HookRegistry hookRegistry = new HookRegistry(); + + if (!contextManagers.isEmpty()) { + // Create context manager hook provider + ContextManagerHookProvider hookProvider = new ContextManagerHookProvider(contextManagers); + + // Update hook configuration based on template + hookProvider.updateHookConfiguration(template.getHooks()); + + // Register hooks + hookProvider.registerHooks(hookRegistry); + + log.debug("Registered context manager hooks for {} managers", contextManagers.size()); + } + + return hookRegistry; + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index e5c11215c3..acbd70bb43 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -76,6 +76,7 @@ public class RestActionUtils { public static final String[] UI_METADATA_EXCLUDE = new String[] { "ui_metadata" }; public static final String PARAMETER_TOOL_NAME = "tool_name"; + public static final String PARAMETER_TEMPLATE_NAME = "template_name"; public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; diff --git 
a/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java new file mode 100644 index 0000000000..24009b5094 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidatorTests.java @@ -0,0 +1,413 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.agent; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.exception.MLResourceNotFoundException; + +import junit.framework.TestCase; + +public class MLAgentRegistrationValidatorTests extends TestCase { + + private ContextManagementTemplateService mockTemplateService; + private MLAgentRegistrationValidator validator; + + @Before + public void setUp() { + mockTemplateService = mock(ContextManagementTemplateService.class); + validator = new MLAgentRegistrationValidator(mockTemplateService); + } + + @Test + public void testValidateAgentForRegistration_NoContextManagement() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since no template reference + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateExists() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("existing_template").build(); + + // Mock template service to return a template (exists) + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + ContextManagementTemplate template = ContextManagementTemplate.builder().name("existing_template").build(); + listener.onResponse(template); + return null; + }).when(mockTemplateService).getTemplate(eq("existing_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void 
onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + verify(mockTemplateService).getTemplate(eq("existing_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateNotFound() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("nonexistent_template").build(); + + // Mock template service to return template not found + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new MLResourceNotFoundException("Context management template not found: nonexistent_template")); + return null; + }).when(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Context management template not found: nonexistent_template")); + + verify(mockTemplateService).getTemplate(eq("nonexistent_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_TemplateServiceError() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("error_template").build(); + + // Mock template service to return an error + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("Service error")); + return null; + }).when(mockTemplateService).getTemplate(eq("error_template"), any()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Failed to validate context management template")); + + verify(mockTemplateService).getTemplate(eq("error_template"), any()); + } + + @Test + public void testValidateAgentForRegistration_InlineContextManagement() throws InterruptedException { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = 
new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertTrue(result.get()); + assertNull(error.get()); + + // Verify template service was not called since using inline configuration + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidTemplateName() throws InterruptedException { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name").build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue( + error + .get() + .getMessage() + .contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots") + ); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateAgentForRegistration_InvalidInlineConfiguration() throws InterruptedException { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference result = new AtomicReference<>(); + AtomicReference error = new AtomicReference<>(); + + validator.validateAgentForRegistration(agent, new ActionListener() { + @Override + public void onResponse(Boolean response) { + result.set(response); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + error.set(e); + latch.countDown(); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertNull(result.get()); + assertNotNull(error.get()); + assertTrue(error.get() instanceof IllegalArgumentException); + assertTrue(error.get().getMessage().contains("Invalid hook name: INVALID_HOOK")); + + // Verify template service was not called due to early validation failure + verify(mockTemplateService, never()).getTemplate(any(), any()); + } + + @Test + public void testValidateContextManagementConfiguration_ValidTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("valid_template_name").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void 
testValidateContextManagementConfiguration_ValidInlineConfig() { + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNull(result); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyTemplateName() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot be null or empty")); + } + + @Test + public void testValidateContextManagementConfiguration_TooLongTemplateName() { + String longName = "a".repeat(257); + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName(longName).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name cannot exceed 256 characters")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidTemplateNameCharacters() { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagementName("invalid@name#").build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Context management template name can only contain letters, numbers, underscores, hyphens, and dots")); + } + + @Test + public void testValidateContextManagementConfiguration_InvalidHookName() { + Map> invalidHooks = new HashMap<>(); + invalidHooks.put("INVALID_HOOK", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(invalidHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Invalid hook name: INVALID_HOOK")); + } + + @Test + public void testValidateContextManagementConfiguration_EmptyHookConfigs() { + Map> emptyHooks = new HashMap<>(); + emptyHooks.put("POST_TOOL", List.of()); + + ContextManagementTemplate contextManagement = ContextManagementTemplate.builder().name("test_template").hooks(emptyHooks).build(); + + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build(); + + String result = validator.validateContextManagementConfiguration(agent); + assertNotNull(result); + assertTrue(result.contains("Hook POST_TOOL must have at least one context manager configuration")); + } + + @Test + public void testValidateContextManagementConfiguration_Conflict() { + // This test should verify that the MLAgent constructor throws an exception + // when both context management name and inline config are provided + Map> hooks = new HashMap<>(); + hooks.put("POST_TOOL", List.of(new ContextManagerConfig("ToolsOutputTruncateManager", null, null))); + + ContextManagementTemplate contextManagement = 
ContextManagementTemplate.builder().name("inline_template").hooks(hooks).build(); + + try { + MLAgent agent = MLAgent + .builder() + .name("test_agent") + .type("flow") + .contextManagementName("template_name") + .contextManagement(contextManagement) + .build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Cannot specify both context_management_name and context_management", e.getMessage()); + } + } + + @Test + public void testValidateContextManagementConfiguration_InvalidInlineConfig() { + // This test should verify that the MLAgent constructor throws an exception + // when invalid context management configuration is provided + ContextManagementTemplate invalidContextManagement = ContextManagementTemplate + .builder() + .name("invalid_template") + .hooks(new HashMap<>()) + .build(); + + try { + MLAgent agent = MLAgent.builder().name("test_agent").type("flow").contextManagement(invalidContextManagement).build(); + fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + assertEquals("Invalid context management configuration", e.getMessage()); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java index 63fd5216e2..5ec7f38311 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/DeleteAgentTransportActionTests.java @@ -353,6 +353,8 @@ private GetResponse prepareMLAgent(String agentId, boolean isHidden, String tena Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java index 8a5e081855..7bed2a7225 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/GetAgentTransportActionTests.java @@ -335,6 +335,8 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenan Instant.EPOCH, "test", isHidden, + null, // contextManagementName + null, // contextManagement tenantId ); diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java index b9e7323b7e..0818845cac 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/RegisterAgentTransportActionTests.java @@ -42,6 +42,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; import org.opensearch.ml.common.MLAgentType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; @@ -96,6 +97,9 @@ public class RegisterAgentTransportActionTests extends OpenSearchTestCase { @Mock private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + @Before public void setup() throws IOException { 
MockitoAnnotations.openMocks(this); @@ -116,7 +120,8 @@ public void setup() throws IOException { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); indexResponse = new IndexResponse(new ShardId(ML_AGENT_INDEX, "_na_", 0), "AGENT_ID", 1, 0, 2, true); } @@ -510,7 +515,8 @@ public void test_execute_registerAgent_MCPConnectorDisabled() { sdkClient, mlIndicesHandler, clusterService, - mlFeatureEnabledSetting + mlFeatureEnabledSetting, + contextManagementTemplateService ); disabledAction.doExecute(task, request, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java new file mode 100644 index 0000000000..c8d1c8953f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementIndexUtilsTests.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.AdminClient; +import org.opensearch.transport.client.Client; +import org.opensearch.transport.client.IndicesAdminClient; + +public class ContextManagementIndexUtilsTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + @Mock + private AdminClient adminClient; + + @Mock + private IndicesAdminClient indicesAdminClient; + + private ContextManagementIndexUtils contextManagementIndexUtils; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(client.admin()).thenReturn(adminClient); + when(adminClient.indices()).thenReturn(indicesAdminClient); + + contextManagementIndexUtils = new ContextManagementIndexUtils(client, clusterService); + } + + @Test + public void testGetIndexName() { + String indexName = ContextManagementIndexUtils.getIndexName(); + assertEquals("ml_context_management_templates", indexName); + } + + @Test + public void testDoesIndexExist_True() { + ClusterState clusterState = 
mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertTrue(exists); + } + + @Test + public void testDoesIndexExist_False() { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + boolean exists = contextManagementIndexUtils.doesIndexExist(); + assertFalse(exists); + } + + @Test + public void testCreateIndexIfNotExists_IndexAlreadyExists() { + // Mock index already exists + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + verify(indicesAdminClient, never()).create(any(), any()); + } + + @Test + public void testCreateIndexIfNotExists_Success() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock successful index creation + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + CreateIndexResponse response = mock(CreateIndexResponse.class); + createListener.onResponse(response); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + + // Verify the create request was made with correct settings + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(CreateIndexRequest.class); + verify(indicesAdminClient).create(requestCaptor.capture(), any()); + + CreateIndexRequest request = requestCaptor.getValue(); + assertEquals("ml_context_management_templates", request.index()); + + Settings indexSettings = request.settings(); + assertEquals("1", indexSettings.get("index.number_of_shards")); + assertEquals("1", indexSettings.get("index.number_of_replicas")); + assertEquals("0-1", indexSettings.get("index.auto_expand_replicas")); + } + + @Test + public void testCreateIndexIfNotExists_ResourceAlreadyExistsException() { + // Mock index doesn't exist initially + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + // Mock ResourceAlreadyExistsException (race 
condition) + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(new ResourceAlreadyExistsException("Index already exists")); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onResponse(true); + } + + @Test + public void testCreateIndexIfNotExists_OtherException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Test exception"); + + // Mock other exception + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener createListener = invocation.getArgument(1); + createListener.onFailure(testException); + return null; + }).when(indicesAdminClient).create(any(CreateIndexRequest.class), any()); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } + + @Test + public void testCreateIndexIfNotExists_UnexpectedException() { + // Mock index doesn't exist + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + RuntimeException testException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception during setup + when(client.admin()).thenThrow(testException); + + contextManagementIndexUtils.createIndexIfNotExists(listener); + + verify(listener).onFailure(testException); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java new file mode 100644 index 0000000000..c8b7391908 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagementTemplateServiceTests.java @@ -0,0 +1,351 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.Client; + +public class 
ContextManagementTemplateServiceTests extends OpenSearchTestCase { + + @Mock + private MLIndicesHandler mlIndicesHandler; + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ThreadPool threadPool; + + private ContextManagementTemplateService contextManagementTemplateService; + + @Before + public void setUp() throws Exception { + super.setUp(); + MockitoAnnotations.openMocks(this); + + // Create a real ThreadContext instead of mocking it + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + // Mock cluster service dependencies for proper setup + org.opensearch.cluster.ClusterState clusterState = mock(org.opensearch.cluster.ClusterState.class); + org.opensearch.cluster.metadata.Metadata metadata = mock(org.opensearch.cluster.metadata.Metadata.class); + when(clusterService.state()).thenReturn(clusterState); + when(clusterState.metadata()).thenReturn(metadata); + when(metadata.hasIndex(anyString())).thenReturn(false); // Default to index not existing + + contextManagementTemplateService = new ContextManagementTemplateService(mlIndicesHandler, client, clusterService); + } + + @Test + public void testConstructor() { + assertNotNull(contextManagementTemplateService); + } + + @Test + public void testSaveTemplate_InvalidTemplate() { + String templateName = "test_template"; + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(false); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate(templateName, template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Invalid context management template", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(5, 20, listener); + + // Verify that the method was called - the actual OpenSearch interaction would be complex to mock + // This at least exercises the method signature and basic flow + verify(client).threadPool(); + } + + @Test + public void testListTemplates_DefaultPagination() { + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + // Verify that the method was called - this exercises the default pagination path + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_NullTemplate() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof NullPointerException); + } + + @Test + public void testSaveTemplate_ValidTemplate() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + 
when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called - the method will fail due to complex mocking requirements + // but this covers the validation path and timestamp setting + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } + + @Test + public void testSaveTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenThrow(new RuntimeException("Validation error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("Validation error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_NullTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(null, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_EmptyTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertEquals("Template name cannot be null or empty", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testGetTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + 
@SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testDeleteTemplate_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate("test_template", listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_ExceptionInTryBlock() { + // Test exception handling in the outer try-catch block by making threadPool throw + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_WithPaginationExceptionInTryBlock() { + // Test exception handling in the outer try-catch block for paginated version + when(client.threadPool()).thenThrow(new RuntimeException("ThreadPool error")); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + + contextManagementTemplateService.listTemplates(10, 50, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener).onFailure(exceptionCaptor.capture()); + assertTrue(exceptionCaptor.getValue() instanceof RuntimeException); + assertEquals("ThreadPool error", exceptionCaptor.getValue().getMessage()); + } + + @Test + public void testListTemplates_NullListener() { + // This should not throw an exception, but we can test that the method handles it gracefully + try { + contextManagementTemplateService.listTemplates(null); + // If we get here without exception, that's fine - the method should handle null listeners gracefully + } catch (Exception e) { + // If an exception is thrown, it should be a meaningful one + assertTrue(e instanceof IllegalArgumentException || e instanceof NullPointerException); + } + } + + @Test + public void testGetTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.getTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testDeleteTemplate_WhitespaceTemplateName() { + @SuppressWarnings("unchecked") 
+ ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.deleteTemplate(" ", listener); + + // Whitespace is not considered empty by Strings.isNullOrEmpty(), so it will proceed + // This tests the branch where template name is not null/empty but contains only whitespace + verify(client).threadPool(); + } + + @Test + public void testSaveTemplate_TemplateWithExistingCreatedTime() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(java.time.Instant.now()); // Already has created time + when(template.getCreatedBy()).thenReturn("existing_user"); // Already has created by + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called and existing values were checked + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should call setLastModified but not setCreatedTime or setCreatedBy since they exist + verify(template).setLastModified(any(java.time.Instant.class)); + verify(template, never()).setCreatedTime(any(java.time.Instant.class)); + verify(template, never()).setCreatedBy(anyString()); + } + + @Test + public void testSaveTemplate_TemplateWithNullCreatedBy() { + ContextManagementTemplate template = mock(ContextManagementTemplate.class); + when(template.isValid()).thenReturn(true); + when(template.getName()).thenReturn("test_template"); + when(template.getCreatedTime()).thenReturn(null); + when(template.getCreatedBy()).thenReturn(null); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + + contextManagementTemplateService.saveTemplate("test_template", template, listener); + + // Verify template validation was called + verify(template).isValid(); + verify(template).getCreatedTime(); + verify(template).getCreatedBy(); + // Should set both created time and last modified + verify(template).setCreatedTime(any(java.time.Instant.class)); + verify(template).setLastModified(any(java.time.Instant.class)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java new file mode 100644 index 0000000000..1e0661c80b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.common.contextmanager.ActivationRuleFactory; +import org.opensearch.ml.common.contextmanager.ContextManager; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.engine.algorithms.contextmanager.SlidingWindowManager; +import org.opensearch.ml.engine.algorithms.contextmanager.SummarizationManager; +import org.opensearch.ml.engine.algorithms.contextmanager.ToolsOutputTruncateManager; +import 
org.opensearch.transport.client.Client; + +public class ContextManagerFactoryTests { + + private ContextManagerFactory contextManagerFactory; + private ActivationRuleFactory activationRuleFactory; + private Client client; + + @Before + public void setUp() { + activationRuleFactory = mock(ActivationRuleFactory.class); + client = mock(Client.class); + contextManagerFactory = new ContextManagerFactory(activationRuleFactory, client); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_ToolsOutputTruncateManagerWithParameters() { + // Arrange + Map parameters = Map.of("maxLength", 1000); + ContextManagerConfig config = new ContextManagerConfig("ToolsOutputTruncateManager", parameters, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof ToolsOutputTruncateManager); + } + + @Test + public void testCreateContextManager_SlidingWindowManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SlidingWindowManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SlidingWindowManager); + } + + @Test + public void testCreateContextManager_SummarizationManager() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("SummarizationManager", null, null); + + // Act + ContextManager contextManager = contextManagerFactory.createContextManager(config); + + // Assert + assertNotNull(contextManager); + assertTrue(contextManager instanceof SummarizationManager); + } + + @Test + public void testCreateContextManager_UnknownType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("UnknownManager", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for unknown manager type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } + + @Test + public void testCreateContextManager_NullConfig() { + // Act & Assert + try { + contextManagerFactory.createContextManager(null); + fail("Expected IllegalArgumentException for null config"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_NullType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig(null, null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for null type"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null")); + } + } + + @Test + public void testCreateContextManager_EmptyType() { + // Arrange + ContextManagerConfig config = new ContextManagerConfig("", null, null); + + // Act & Assert + try { + contextManagerFactory.createContextManager(config); + fail("Expected IllegalArgumentException for empty type"); + } catch 
(IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Unsupported context manager type")); + } + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..859cc1fbd7 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/CreateContextManagementTemplateTransportActionTests.java @@ -0,0 +1,196 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class CreateContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private CreateContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new CreateContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + 
ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLCreateContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("created", response.getStatus()); + } + + @Test + public void testDoExecute_SaveFailure() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Failed to create context management template", exception.getMessage()); + } + + @Test + public void testDoExecute_SaveException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException saveException = new RuntimeException("Database error"); + + // Mock exception during template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onFailure(saveException); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(saveException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).saveTemplate(any(), any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLCreateContextManagementTemplateRequest request = new MLCreateContextManagementTemplateRequest(templateName, template); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template save + doAnswer(invocation -> { + ActionListener saveListener = invocation.getArgument(2); + saveListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + + 
transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).saveTemplate(eq(templateName), eq(template), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..a160fabc59 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/DeleteContextManagementTemplateTransportActionTests.java @@ -0,0 +1,174 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class DeleteContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private DeleteContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new DeleteContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + 
doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLDeleteContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(templateName, response.getTemplateName()); + assertEquals("deleted", response.getStatus()); + } + + @Test + public void testDoExecute_DeleteFailure() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock failed template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(false); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: test_template", exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).deleteTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + MLDeleteContextManagementTemplateRequest request = new MLDeleteContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template deletion + doAnswer(invocation -> { + ActionListener deleteListener = invocation.getArgument(1); + deleteListener.onResponse(true); + return null; + }).when(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + + 
transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).deleteTemplate(eq(templateName), any()); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java new file mode 100644 index 0000000000..4bb1328518 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/GetContextManagementTemplateTransportActionTests.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Collections; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class GetContextManagementTemplateTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private GetContextManagementTemplateTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new GetContextManagementTemplateTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + 
transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLGetContextManagementTemplateResponse response = responseCaptor.getValue(); + assertEquals(template, response.getTemplate()); + } + + @Test + public void testDoExecute_TemplateNotFound() { + String templateName = "nonexistent_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock template not found (null response) + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(null); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(listener).onFailure(exceptionCaptor.capture()); + + RuntimeException exception = exceptionCaptor.getValue(); + assertEquals("Context management template not found: " + templateName, exception.getMessage()); + } + + @Test + public void testDoExecute_ServiceException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + String templateName = "test_template"; + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).getTemplate(any(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + String templateName = "test_template"; + ContextManagementTemplate template = createTestTemplate(); + MLGetContextManagementTemplateRequest request = new MLGetContextManagementTemplateRequest(templateName); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + // Mock successful template retrieval + doAnswer(invocation -> { + ActionListener getListener = invocation.getArgument(1); + getListener.onResponse(template); + return null; + }).when(contextManagementTemplateService).getTemplate(eq(templateName), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).getTemplate(eq(templateName), any()); + } + + private ContextManagementTemplate createTestTemplate() { + Map 
config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name("test_template") + .description("Test template") + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java new file mode 100644 index 0000000000..c5951ee868 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ListContextManagementTemplatesTransportActionTests.java @@ -0,0 +1,235 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.contextmanagement; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.contextmanager.ContextManagementTemplate; +import org.opensearch.ml.common.contextmanager.ContextManagerConfig; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesResponse; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class ListContextManagementTemplatesTransportActionTests extends OpenSearchTestCase { + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private ContextManagementTemplateService contextManagementTemplateService; + + @InjectMocks + private ListContextManagementTemplatesTransportAction transportAction; + + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + transportAction = new ListContextManagementTemplatesTransportAction( + transportService, + actionFilters, + client, + contextManagementTemplateService + ); + } + + @Test + public void testDoExecute_Success() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1"), createTestTemplate("template2")); + + 
// Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + } + + @Test + public void testDoExecute_EmptyList() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List emptyTemplates = Collections.emptyList(); + + // Mock empty template list + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(emptyTemplates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(emptyTemplates, response.getTemplates()); + assertTrue(response.getTemplates().isEmpty()); + } + + @Test + public void testDoExecute_ServiceException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException serviceException = new RuntimeException("Database error"); + + // Mock exception during template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onFailure(serviceException); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(serviceException); + } + + @Test + public void testDoExecute_UnexpectedException() { + int from = 0; + int size = 10; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + RuntimeException unexpectedException = new RuntimeException("Unexpected error"); + + // Mock unexpected exception + doThrow(unexpectedException).when(contextManagementTemplateService).listTemplates(anyInt(), anyInt(), any()); + + transportAction.doExecute(task, request, listener); + + verify(listener).onFailure(unexpectedException); + } + + @Test + public void testDoExecute_VerifyServiceCall() { + int from = 5; + int size = 20; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template1")); + + // Mock successful template listing + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + 
listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + // Verify the service was called with correct parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + @Test + public void testDoExecute_CustomPagination() { + int from = 10; + int size = 5; + MLListContextManagementTemplatesRequest request = new MLListContextManagementTemplatesRequest(from, size); + Task task = mock(Task.class); + ActionListener listener = mock(ActionListener.class); + + List templates = Arrays.asList(createTestTemplate("template3")); + + // Mock successful template listing with custom pagination + doAnswer(invocation -> { + ActionListener> listListener = invocation.getArgument(2); + listListener.onResponse(templates); + return null; + }).when(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + + transportAction.doExecute(task, request, listener); + + ArgumentCaptor responseCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesResponse.class); + verify(listener).onResponse(responseCaptor.capture()); + + MLListContextManagementTemplatesResponse response = responseCaptor.getValue(); + assertEquals(templates, response.getTemplates()); + + // Verify the service was called with custom pagination parameters + verify(contextManagementTemplateService).listTemplates(eq(from), eq(size), any()); + } + + private ContextManagementTemplate createTestTemplate(String name) { + Map config = Collections.singletonMap("summary_ratio", 0.3); + ContextManagerConfig contextManagerConfig = new ContextManagerConfig("SummarizationManager", null, config); + + return ContextManagementTemplate + .builder() + .name(name) + .description("Test template " + name) + .hooks(Collections.singletonMap("PreLLMEvent", Collections.singletonList(contextManagerConfig))) + .build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..ea585e0459 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateContextManagementTemplateActionTests.java @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLCreateContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import 
org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLCreateContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLCreateContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLCreateContextManagementTemplateAction action = new RestMLCreateContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_create_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLCreateContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLCreateContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + 
assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceOnlyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithInvalidJsonContent() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray("invalid json"), XContentType.JSON) + .build(); + + assertThrows(Exception.class, () -> restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLCreateContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + assertNotNull(result.getTemplate()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(params) + .withContent(new BytesArray(getValidTemplateContent()), XContentType.JSON) + .build(); + } + + private String getValidTemplateContent() { + return "{\n" + + " \"description\": \"Test template\",\n" + + " \"hooks\": {\n" + + " \"PreLLMEvent\": [\n" + + " {\n" + + " \"type\": \"SummarizationManager\",\n" + + " \"config\": {\n" + + " \"summary_ratio\": 0.3,\n" + + " \"preserve_recent_messages\": 10\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + } +} diff --git 
a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..520dd05d19 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteContextManagementTemplateActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLDeleteContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLDeleteContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLDeleteContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLDeleteContextManagementTemplateAction action = new RestMLDeleteContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_delete_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.DELETE, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws 
Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLDeleteContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLDeleteContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + assertThrows(IllegalArgumentException.class, () -> 
restAction.getRequest(request)); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLDeleteContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java new file mode 100644 index 0000000000..abe0d3edaa --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetContextManagementTemplateActionTests.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_TEMPLATE_NAME; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateAction; +import org.opensearch.ml.common.transport.contextmanagement.MLGetContextManagementTemplateRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLGetContextManagementTemplateActionTests extends OpenSearchTestCase { + private RestMLGetContextManagementTemplateAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLGetContextManagementTemplateAction action = new RestMLGetContextManagementTemplateAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String 
actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_get_context_management_template_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management/{template_name}", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithMissingTemplateName() { + Map params = new HashMap<>(); + // No template name parameter + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithEmptyTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, ""); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testGetRequestWithWhitespaceTemplateName() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, " "); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> restAction.getRequest(request)); + assertEquals("Template name is required", exception.getMessage()); + } + + public void testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLGetContextManagementTemplateRequest.class); + verify(client, 
times(1)).execute(eq(MLGetContextManagementTemplateAction.INSTANCE), argumentCaptor.capture(), any()); + String templateName = argumentCaptor.getValue().getTemplateName(); + assertEquals("test_template", templateName); + } + + public void testGetRequestWithValidInput() throws Exception { + RestRequest request = getRestRequest(); + MLGetContextManagementTemplateRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals("test_template", result.getTemplateName()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_TEMPLATE_NAME, "test_template"); + + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java new file mode 100644 index 0000000000..d1f56a934a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLListContextManagementTemplatesActionTests.java @@ -0,0 +1,184 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesAction; +import org.opensearch.ml.common.transport.contextmanagement.MLListContextManagementTemplatesRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLListContextManagementTemplatesActionTests extends OpenSearchTestCase { + private RestMLListContextManagementTemplatesAction restAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private RestChannel channel; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + restAction = new RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + doAnswer(invocation -> { + // Mock successful execution - actionListener not used in this test + return null; + }).when(client).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLListContextManagementTemplatesAction action = new 
RestMLListContextManagementTemplatesAction(mlFeatureEnabledSetting); + assertNotNull(action); + } + + public void testGetName() { + String actionName = restAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_list_context_management_templates_action", actionName); + } + + public void testRoutes() { + List routes = restAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.GET, route.getMethod()); + assertEquals("/_plugins/_ml/context_management", route.getPath()); + } + + public void testPrepareRequest() throws Exception { + RestRequest request = getRestRequest(); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + public void testPrepareRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(5, 20); + restAction.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(5, capturedRequest.getFrom()); + assertEquals(20, capturedRequest.getSize()); + } + + public void testPrepareRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } + + public void testGetRequestWithDefaultPagination() throws Exception { + RestRequest request = getRestRequest(); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(0, result.getFrom()); + assertEquals(10, result.getSize()); + } + + public void testGetRequestWithCustomPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(15, 25); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + assertEquals(15, result.getFrom()); + assertEquals(25, result.getSize()); + } + + public void testGetRequestWithAgentFrameworkDisabled() { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + + IllegalStateException exception = assertThrows(IllegalStateException.class, () -> restAction.getRequest(request)); + assertEquals("Agent framework is disabled", exception.getMessage()); + } + + public void testGetRequestWithInvalidPagination() throws Exception { + RestRequest request = getRestRequestWithPagination(-1, -5); + MLListContextManagementTemplatesRequest result = restAction.getRequest(request); + + assertNotNull(result); + // The REST action passes through the parameters as-is, validation happens at the service level + assertEquals(-1, result.getFrom()); + assertEquals(-5, result.getSize()); + } + + public void 
testPrepareRequestReturnsRestChannelConsumer() throws Exception { + RestRequest request = getRestRequest(); + Object consumer = restAction.prepareRequest(request, client); + + assertNotNull(consumer); + + // Execute the consumer to test the actual execution path using reflection + java.lang.reflect.Method acceptMethod = consumer.getClass().getMethod("accept", Object.class); + acceptMethod.invoke(consumer, channel); + + ArgumentCaptor argumentCaptor = ArgumentCaptor + .forClass(MLListContextManagementTemplatesRequest.class); + verify(client, times(1)).execute(eq(MLListContextManagementTemplatesAction.INSTANCE), argumentCaptor.capture(), any()); + MLListContextManagementTemplatesRequest capturedRequest = argumentCaptor.getValue(); + assertEquals(0, capturedRequest.getFrom()); + assertEquals(10, capturedRequest.getSize()); + } + + private RestRequest getRestRequest() { + Map params = new HashMap<>(); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } + + private RestRequest getRestRequestWithPagination(int from, int size) { + Map params = new HashMap<>(); + params.put("from", String.valueOf(from)); + params.put("size", String.valueOf(size)); + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 1f53744661..d1d0eca084 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -33,6 +33,8 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.action.contextmanagement.ContextManagementTemplateService; +import org.opensearch.ml.action.contextmanagement.ContextManagerFactory; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -78,6 +80,10 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { DiscoveryNodeHelper nodeHelper; @Mock ClusterApplierService clusterApplierService; + @Mock + ContextManagementTemplateService contextManagementTemplateService; + @Mock + ContextManagerFactory contextManagerFactory; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -138,7 +144,9 @@ public void setup() { mlTaskDispatcher, mlCircuitBreakerService, nodeHelper, - mlEngine + mlEngine, + contextManagementTemplateService, + contextManagerFactory ) ); From 40c84e98b93d757a20ebdbd011457bcefee2ac61 Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Tue, 4 Nov 2025 17:05:40 -0800 Subject: [PATCH 19/58] fix per body template for new execute api (per only supports text input atm) Signed-off-by: Pavan Yekbote --- .../agent/MLPlanExecuteAndReflectAgentRunner.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 9ad4bed833..b989eb3456 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ 
b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -162,6 +162,9 @@ public class MLPlanExecuteAndReflectAgentRunner implements MLAgentRunner { public static final String INJECT_DATETIME_FIELD = "inject_datetime"; public static final String DATETIME_FORMAT_FIELD = "datetime_format"; + private static final String BODY_FIELD = "body"; + private static final String BODY_TEMPLATE = "{\"role\":\"user\",\"content\":[{\"text\":\"${parameters.prompt}\"}]}"; + public MLPlanExecuteAndReflectAgentRunner( Client client, Settings settings, @@ -194,6 +197,10 @@ void setupPromptParameters(Map params) { // populated depending on whether LLM is asked to plan or re-evaluate // removed here, so that error is thrown in case this field is not populated params.remove(PROMPT_FIELD); + // workaround for agent revamp until PER supports messages + if (params.containsKey(BODY_FIELD)) { + params.put(BODY_FIELD, BODY_TEMPLATE); + } String userPrompt = params.get(QUESTION_FIELD); params.put(USER_PROMPT_FIELD, userPrompt); From 82f34444666aa6e5fd765f026656bcc25d50e9b8 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Tue, 4 Nov 2025 23:36:06 -0800 Subject: [PATCH 20/58] added multi-tenancy to the add memory and search memory action --- .../ml/common/settings/MLCommonsSettings.java | 2 +- .../memory/MLAddMemoriesInput.java | 16 ++++++++++++++-- .../memory/MLGetMemoryRequest.java | 7 ++++++- .../memory/TransportAddMemoriesAction.java | 2 +- .../memory/TransportGetMemoryAction.java | 3 ++- .../ml/helper/MemoryContainerHelper.java | 7 +++++-- .../ml/rest/RestMLAddMemoriesAction.java | 5 ++++- .../ml/rest/RestMLGetMemoryAction.java | 4 +++- .../memory/TransportGetMemoryActionTests.java | 14 +++++++------- 9 files changed, 43 insertions(+), 17 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index c139ea4b68..416bcbe1f3 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -431,7 +431,7 @@ private MLCommonsSettings() {} * After setting this option, a full cluster restart is required for the changes to take effect. 
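For orientation, a minimal sketch of how the tenant id added in this patch flows from a REST request into the memory input (the handler variables here are assumed; parse(...) matches the MLAddMemoriesInput hunk below):

    // Resolve the tenant id only when the node-scope multi-tenancy flag above is enabled,
    // then thread it into the input alongside the memory container id.
    String tenantId = TenantAwareHelper.getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), restRequest);
    MLAddMemoriesInput input = MLAddMemoriesInput.parse(parser, memoryContainerId, tenantId);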
*/ public static final Setting ML_COMMONS_MULTI_TENANCY_ENABLED = Setting - .boolSetting(ML_PLUGIN_SETTING_PREFIX + "multi_tenancy_enabled", false, Setting.Property.NodeScope); + .boolSetting(ML_PLUGIN_SETTING_PREFIX + "multi_tenancy_enabled", true, Setting.Property.NodeScope); /** This setting sets the remote metadata type */ public static final Setting REMOTE_METADATA_TYPE = Setting diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java index 350b3594d5..103cbbab7b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLAddMemoriesInput.java @@ -6,6 +6,7 @@ package org.opensearch.ml.common.transport.memorycontainer.memory; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.AGENT_ID_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.BINARY_DATA_FIELD; import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.CHECKPOINT_ID_FIELD; @@ -72,6 +73,7 @@ public class MLAddMemoriesInput implements ToXContentObject, Writeable { // Checkpoint field private String checkpointId; + private String tenantId; public MLAddMemoriesInput( String memoryContainerId, @@ -86,7 +88,8 @@ public MLAddMemoriesInput( Map tags, Map parameters, String ownerId, - String checkpointId + String checkpointId, + String tenantId ) { // MAX_MESSAGES_PER_REQUEST limit removed for performance testing @@ -106,6 +109,7 @@ public MLAddMemoriesInput( } this.ownerId = ownerId; this.checkpointId = checkpointId; + this.tenantId = tenantId; validate(); } @@ -151,6 +155,7 @@ public MLAddMemoriesInput(StreamInput in) throws IOException { } this.ownerId = in.readOptionalString(); this.checkpointId = in.readOptionalString(); + this.tenantId = in.readOptionalString(); } @Override @@ -201,6 +206,7 @@ public void writeTo(StreamOutput out) throws IOException { } out.writeOptionalString(ownerId); out.writeOptionalString(checkpointId); + out.writeOptionalString(tenantId); } @Override @@ -250,6 +256,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (checkpointId != null) { builder.field(CHECKPOINT_ID_FIELD, checkpointId); } + if (tenantId != null) { + builder.field(TENANT_ID_FIELD, tenantId); + } if (withTimeStamp) { Instant now = Instant.now(); builder.field(CREATED_TIME_FIELD, now.toEpochMilli()); @@ -259,7 +268,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } - public static MLAddMemoriesInput parse(XContentParser parser, String memoryContainerId) throws IOException { + public static MLAddMemoriesInput parse(XContentParser parser, String memoryContainerId, String tenantId) throws IOException { String payloadType = null; List messages = null; Integer messageId = null; @@ -322,6 +331,8 @@ public static MLAddMemoriesInput parse(XContentParser parser, String memoryConta case CHECKPOINT_ID_FIELD: checkpointId = parser.text(); break; + case TENANT_ID_FIELD: + tenantId = parser.textOrNull(); default: parser.skipChildren(); break; @@ -343,6 +354,7 @@ public static MLAddMemoriesInput parse(XContentParser parser, String 
memoryConta .parameters(parameters) .ownerId(ownerId) .checkpointId(checkpointId) + .tenantId(tenantId) .build(); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java index 2fbd3aff02..8ad192d575 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java @@ -34,12 +34,15 @@ public class MLGetMemoryRequest extends ActionRequest { String memoryContainerId; MemoryType memoryType; String memoryId; + String tenantId; + @Builder - public MLGetMemoryRequest(String memoryContainerId, MemoryType memoryType, String memoryId) { + public MLGetMemoryRequest(String memoryContainerId, MemoryType memoryType, String memoryId, String tenantId) { this.memoryContainerId = memoryContainerId; this.memoryType = memoryType; this.memoryId = memoryId; + this.tenantId = tenantId; } public MLGetMemoryRequest(StreamInput in) throws IOException { @@ -47,6 +50,7 @@ public MLGetMemoryRequest(StreamInput in) throws IOException { this.memoryContainerId = in.readString(); this.memoryType = in.readEnum(MemoryType.class); this.memoryId = in.readString(); + this.tenantId = in.readString(); } @Override @@ -55,6 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(this.memoryContainerId); out.writeEnum(this.memoryType); out.writeString(this.memoryId); + out.writeString(this.tenantId); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java index d26919f36c..7b98277512 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java @@ -124,7 +124,7 @@ protected void doExecute(Task task, MLAddMemoriesRequest request, ActionListener return; } - memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(container -> { + memoryContainerHelper.getMemoryContainer(memoryContainerId, request.getMlAddMemoryInput().getTenantId(), ActionListener.wrap(container -> { if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { actionListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryAction.java index c8b093af66..d8b003505f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportGetMemoryAction.java @@ -69,9 +69,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + memoryContainerHelper.getMemoryContainer(memoryContainerId, tenantId, ActionListener.wrap(container -> { // Validate access permissions User user = RestActionUtils.getUserContext(client); if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index 
7047f2a37d..e795a290bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -127,6 +127,8 @@ public void getMemoryContainer(String memoryContainerId, ActionListener { @@ -215,7 +215,7 @@ public void testDoExecuteSuccess() { public void testDoExecuteWithUnauthorizedUser() { // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Setup memory container helper to return container doAnswer(invocation -> { @@ -250,7 +250,7 @@ public void testDoExecuteWithUnauthorizedUser() { public void testDoExecuteWithParsingException() { // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Setup memory container helper to return container doAnswer(invocation -> { @@ -299,7 +299,7 @@ public void testDoExecuteWithParsingException() { @Test public void testDoExecuteWithNoResponse() { // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Setup memory container helper to return container doAnswer(invocation -> { @@ -343,7 +343,7 @@ public void testDoExecuteWithNoResponse() { @Test public void testDoExecuteWithClientGetFailure() { // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Setup memory container helper to return container doAnswer(invocation -> { @@ -383,7 +383,7 @@ public void testDoExecuteWithClientGetFailure() { @Test public void testDoExecuteWithProcessResponseException() { // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Setup memory container helper to return container doAnswer(invocation -> { @@ -442,7 +442,7 @@ public void testDoExecuteWithFeatureDisabled() { when(mlFeatureEnabledSetting.isAgenticMemoryEnabled()).thenReturn(false); // Setup request - MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID); + MLGetMemoryRequest getRequest = new MLGetMemoryRequest(MEMORY_CONTAINER_ID, MEMORY_TYPE, MEMORY_ID, null); // Execute action.doExecute(task, getRequest, actionListener); From 22e5ce48eb9cd108ba31df1a924f70a94fc09796 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 5 Nov 2025 12:14:21 -0800 Subject: [PATCH 21/58] added client.search instead of sdkclient for collection index Signed-off-by: Dhrubo Saha --- .../memory/MLGetMemoryRequest.java | 1 - .../memory/MemorySearchService.java | 11 ++----- .../memory/TransportAddMemoriesAction.java | 24 ++++++++------ .../memory/TransportSearchMemoriesAction.java | 21 +++++-------- .../ml/helper/MemoryContainerHelper.java | 31 +++++-------------- .../ml/rest/RestMLAddMemoriesAction.java | 1 - 6 files changed, 33 insertions(+), 56 deletions(-) diff --git 
a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java index 8ad192d575..f826952aba 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLGetMemoryRequest.java @@ -36,7 +36,6 @@ public class MLGetMemoryRequest extends ActionRequest { String memoryId; String tenantId; - @Builder public MLGetMemoryRequest(String memoryContainerId, MemoryType memoryType, String memoryId, String tenantId) { this.memoryContainerId = memoryContainerId; diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java index f00ad80cd1..600a58697b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemorySearchService.java @@ -11,6 +11,7 @@ import java.util.List; import java.util.Map; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.QueryBuilder; @@ -19,7 +20,6 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput; import org.opensearch.ml.helper.MemoryContainerHelper; import org.opensearch.ml.utils.MemorySearchQueryBuilder; -import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -111,13 +111,8 @@ private void searchFactsSequentially( searchSourceBuilder.size(maxInferSize); searchSourceBuilder.fetchSource(new String[] { MEMORY_FIELD }, null); - SearchDataObjectRequest searchRequest = SearchDataObjectRequest - .builder() - .indices(indexName) - .searchSourceBuilder(searchSourceBuilder) - .tenantId(tenantId) - .build(); - // TODO: add search pipeline support in SearchDataObjectRequest + SearchRequest searchRequest = new SearchRequest(indexName).source(searchSourceBuilder); + // TODO: add search pipeline support in SearchRequest // if (memoryConfig.getSearchPipeline() != null) { // searchRequest.pipeline(memoryConfig.getSearchPipeline()); // } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java index 7b98277512..f55a2d0a2a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java @@ -124,16 +124,20 @@ protected void doExecute(Task task, MLAddMemoriesRequest request, ActionListener return; } - memoryContainerHelper.getMemoryContainer(memoryContainerId, request.getMlAddMemoryInput().getTenantId(), ActionListener.wrap(container -> { - if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { - actionListener - .onFailure( - new OpenSearchStatusException("User doesn't have permissions to add memory to this container", RestStatus.FORBIDDEN) - ); - return; - } - createNewSessionIfAbsent(input, container, user, actionListener); - }, 
actionListener::onFailure)); + memoryContainerHelper + .getMemoryContainer(memoryContainerId, request.getMlAddMemoryInput().getTenantId(), ActionListener.wrap(container -> { + if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { + actionListener + .onFailure( + new OpenSearchStatusException( + "User doesn't have permissions to add memory to this container", + RestStatus.FORBIDDEN + ) + ); + return; + } + createNewSessionIfAbsent(input, container, user, actionListener); + }, actionListener::onFailure)); } private void createNewSessionIfAbsent( diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java index dae30eda32..f20c03e6e7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java @@ -10,6 +10,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -28,7 +29,6 @@ import org.opensearch.ml.helper.MemoryContainerHelper; import org.opensearch.ml.utils.RestActionUtils; import org.opensearch.ml.utils.TenantAwareHelper; -import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -118,19 +118,14 @@ private void searchMemories( memoryContainerHelper.addContainerIdFilter(input.getMemoryContainerId(), input.getSearchSourceBuilder()); // Add owner filter for non-admin users - if (!memoryContainerHelper.isAdminUser(user)) { - memoryContainerHelper.addOwnerIdFilter(user, input.getSearchSourceBuilder()); - } + // if (!memoryContainerHelper.isAdminUser(user)) { + // memoryContainerHelper.addOwnerIdFilter(user, input.getSearchSourceBuilder()); + // } - SearchDataObjectRequest searchDataObjecRequest = SearchDataObjectRequest - .builder() - .indices(indexName) - .searchSourceBuilder(input.getSearchSourceBuilder()) - .tenantId(tenantId) - .build(); - // TODO: add search pipeline support in SearchDataObjectRequest + SearchRequest searchRequest = new SearchRequest(indexName).source(input.getSearchSourceBuilder()); + // TODO: add search pipeline support in SearchRequest // if (memoryConfig.getSearchPipeline() != null) { - // searchDataObjecRequest.pipeline(memoryConfig.getSearchPipeline()); + // searchRequest.pipeline(memoryConfig.getSearchPipeline()); // } // Execute search @@ -147,7 +142,7 @@ private void searchMemories( }); if (memoryConfig.getRemoteStore() == null) { - memoryContainerHelper.searchData(container.getConfiguration(), searchDataObjecRequest, searchResponseActionListener); + memoryContainerHelper.searchData(container.getConfiguration(), searchRequest, searchResponseActionListener); } else { String query = input.getSearchSourceBuilder().toString(); memoryContainerHelper.searchDataFromRemoteStorage(memoryConfig, indexName, query, searchResponseActionListener); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java 
b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index e795a290bb..cd9a7064f1 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -37,6 +37,7 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.clustermanager.AcknowledgedResponse; import org.opensearch.action.update.UpdateRequest; @@ -71,7 +72,6 @@ import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; -import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -127,8 +127,6 @@ public void getMemoryContainer(String memoryContainerId, ActionListener listener - ) { + public void searchData(MemoryConfiguration configuration, SearchRequest searchRequest, ActionListener listener) { try { // Check if remote store is configured (either with connectorId or internal connector) if (configuration.getRemoteStore() != null @@ -305,13 +299,13 @@ public void searchData( final ActionListener doubleWrappedListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - sdkClient.searchDataObjectAsync(searchRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener)); + client.search(searchRequest, doubleWrappedListener); } } else { final ActionListener doubleWrappedListener = ActionListener .wrap(listener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, listener)); - sdkClient.searchDataObjectAsync(searchRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(doubleWrappedListener)); + client.search(searchRequest, doubleWrappedListener); } } catch (Exception e) { log.error("Failed to search data", e); @@ -771,7 +765,7 @@ public SearchSourceBuilder addContainerIdFilter(String containerId, SearchSource return searchSourceBuilder; } BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - boolQueryBuilder.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD, containerId)); + boolQueryBuilder.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD + ".keyword", containerId)); return applyFilterToSearchSource(searchSourceBuilder, boolQueryBuilder); } @@ -792,7 +786,7 @@ public QueryBuilder addContainerIdFilter(String containerId, QueryBuilder query) BoolQueryBuilder filteredQuery = QueryBuilders.boolQuery(); filteredQuery.must(query); - filteredQuery.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD, containerId)); + filteredQuery.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD + ".keyword", containerId)); return filteredQuery; } @@ -820,16 +814,7 @@ public void countContainersWithPrefix(String indexPrefix, String tenantId, Actio searchSourceBuilder.size(0); // We only need the total count searchSourceBuilder.trackTotalHits(true); - SearchDataObjectRequest.Builder requestBuilder = SearchDataObjectRequest - .builder() - .indices(ML_MEMORY_CONTAINER_INDEX) - .searchSourceBuilder(searchSourceBuilder); - - if (tenantId != null) { - 
requestBuilder.tenantId(tenantId); - } - - SearchDataObjectRequest searchRequest = requestBuilder.build(); + SearchRequest searchRequest = new SearchRequest(ML_MEMORY_CONTAINER_INDEX).source(searchSourceBuilder); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(ActionListener.wrap(response -> { @@ -840,7 +825,7 @@ public void countContainersWithPrefix(String indexPrefix, String tenantId, Actio listener.onFailure(e); }), context::restore); - sdkClient.searchDataObjectAsync(searchRequest).whenComplete(SdkClientUtils.wrapSearchCompletion(wrappedListener)); + client.search(searchRequest, wrappedListener); } catch (Exception e) { log.error("Failed to search for containers with prefix: " + indexPrefix, e); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLAddMemoriesAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLAddMemoriesAction.java index 0025b8c63a..3b4f9d6f69 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLAddMemoriesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLAddMemoriesAction.java @@ -76,7 +76,6 @@ private MLAddMemoriesRequest getRequest(RestRequest request) throws IOException String tenantId = TenantAwareHelper.getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); MLAddMemoriesInput mlAddMemoryInput = MLAddMemoriesInput.parse(parser, memoryContainerId, tenantId); - return new MLAddMemoriesRequest(mlAddMemoryInput); } } From d3bdf82747e0dd4f04483a9b003c69fb4a2bac93 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 5 Nov 2025 15:08:13 -0800 Subject: [PATCH 22/58] adding tenancy to the create session --- .../ml/action/session/TransportCreateSessionAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/session/TransportCreateSessionAction.java b/plugin/src/main/java/org/opensearch/ml/action/session/TransportCreateSessionAction.java index 7869ff16b5..5ecd267e46 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/session/TransportCreateSessionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/session/TransportCreateSessionAction.java @@ -88,7 +88,7 @@ protected void doExecute(Task task, MLCreateSessionRequest request, ActionListen return; } - memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(container -> { + memoryContainerHelper.getMemoryContainer(memoryContainerId, tenantId, ActionListener.wrap(container -> { if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { actionListener .onFailure( From 27297632d8a9fc25becb73d4f3fcd2637810f6f4 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Wed, 5 Nov 2025 17:17:58 -0800 Subject: [PATCH 23/58] Allow context management inline create in register agent without storing in index (#4403) * allow inline create context management without storing in agent register Signed-off-by: Mingshi Liu * make ML_COMMONS_MULTI_TENANCY_ENABLED default is false Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu --- .../opensearch/ml/common/agent/MLAgent.java | 8 +++++++ .../ContextManagerHookProvider.java | 14 +++++++---- .../ml/common/settings/MLCommonsSettings.java | 2 +- .../algorithms/agent/MLAgentExecutor.java | 14 +++++++++++ .../MLPlanExecuteAndReflectAgentRunner.java | 1 + .../algorithms/agent/MLAgentExecutorTest.java | 17 +++++++++++++ .../ContextManagerFactoryTests.java | 24 +++++++++++++++++++ 7 
files changed, 75 insertions(+), 5 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index 06a4d69ebe..52c7826c45 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -519,6 +519,14 @@ public boolean hasContextManagementTemplate() { return contextManagementName != null; } + /** + * Check if this agent has inline context management configuration + * @return true if agent has inline context management configuration + */ + public boolean hasInlineContextManagement() { + return contextManagement != null; + } + /** * Get the context management template name if this agent references one * @return the template name, or null if no template reference diff --git a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java index 35109c53dd..dfef018c87 100644 --- a/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java +++ b/common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerHookProvider.java @@ -47,10 +47,16 @@ public ContextManagerHookProvider(List contextManagers) { */ @Override public void registerHooks(HookRegistry registry) { - // Register callbacks for each hook type - registry.addCallback(PreLLMEvent.class, this::handlePreLLM); - registry.addCallback(EnhancedPostToolEvent.class, this::handlePostTool); - registry.addCallback(PostMemoryEvent.class, this::handlePostMemory); + // Only register callbacks for hooks that have managers configured + if (hookToManagersMap.containsKey("PRE_LLM")) { + registry.addCallback(PreLLMEvent.class, this::handlePreLLM); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registry.addCallback(EnhancedPostToolEvent.class, this::handlePostTool); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registry.addCallback(PostMemoryEvent.class, this::handlePostMemory); + } log.info("Registered context manager hooks for {} managers", contextManagers.size()); } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index 416bcbe1f3..c139ea4b68 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -431,7 +431,7 @@ private MLCommonsSettings() {} * After setting this option, a full cluster restart is required for the changes to take effect. 
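A brief sketch of how the two configuration styles distinguished above might be branched on at execution time; the getter names and the two helper methods are assumptions for illustration, not taken from the patch:

    if (mlAgent.hasContextManagementTemplate()) {
        // template reference: resolve the named template from the template index (hypothetical helper)
        loadContextManagementTemplate(mlAgent.getContextManagementName(), listener);
    } else if (mlAgent.hasInlineContextManagement()) {
        // inline configuration: apply directly, no index round-trip needed (hypothetical helper)
        applyInlineContextManagement(mlAgent.getContextManagement(), listener);
    }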
*/ public static final Setting ML_COMMONS_MULTI_TENANCY_ENABLED = Setting - .boolSetting(ML_PLUGIN_SETTING_PREFIX + "multi_tenancy_enabled", true, Setting.Property.NodeScope); + .boolSetting(ML_PLUGIN_SETTING_PREFIX + "multi_tenancy_enabled", false, Setting.Property.NodeScope); /** This setting sets the remote metadata type */ public static final Setting REMOTE_METADATA_TYPE = Setting diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index bb37abe940..87b26b2957 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -509,6 +509,20 @@ private void processContextManagement(MLAgent mlAgent, HookRegistry hookRegistry return; // Runtime parameter takes precedence, let MLExecuteTaskRunner handle it } + // Check if already processed to avoid duplicate registrations + if ("true".equals(inputDataSet.getParameters().get("context_management_processed"))) { + log.debug("Context management already processed for this execution, skipping"); + return; + } + + // Check if HookRegistry already has callbacks (from previous runtime setup) + // Don't override with inline configuration if runtime config is already active + if (hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.EnhancedPostToolEvent.class) > 0 + || hookRegistry.getCallbackCount(org.opensearch.ml.common.hooks.PreLLMEvent.class) > 0) { + log.info("HookRegistry already has active configuration, skipping inline context management"); + return; + } + ContextManagementTemplate template = null; String templateName = null; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index b989eb3456..9362448300 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -499,6 +499,7 @@ private void executePlanningLoop( .build(); // Pass hookRegistry to internal agent execution + // TODO need to check if the agentInput already have the hookResgistry? 
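Condensed into a single guard, the duplicate-registration checks added above look roughly like this; the method name is illustrative, while the flag key and the getCallbackCount(...) checks follow the MLAgentExecutor hunk:

    private boolean shouldApplyInlineContextManagement(Map<String, String> params, HookRegistry hookRegistry) {
        if ("true".equals(params.get("context_management_processed"))) {
            return false; // already handled earlier in this execution
        }
        // runtime-configured hooks take precedence over the agent's inline configuration
        if (hookRegistry.getCallbackCount(PreLLMEvent.class) > 0
            || hookRegistry.getCallbackCount(EnhancedPostToolEvent.class) > 0) {
            return false;
        }
        return true;
    }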
agentInput.setHookRegistry(hookRegistry); MLExecuteTaskRequest executeRequest = new MLExecuteTaskRequest(FunctionName.AGENT, agentInput); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index 2e8c612c43..a641753cdb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -301,4 +301,21 @@ private MLAgent createTestAgent(String type) { .appType("test-app") .build(); } + + @Test + public void testContextManagementProcessedFlagPreventsReprocessing() { + // Test that the context_management_processed flag prevents duplicate processing + Map parameters = new HashMap<>(); + + // First check - should allow processing + boolean shouldProcess1 = !"true".equals(parameters.get("context_management_processed")); + assertTrue("First call should allow processing", shouldProcess1); + + // Mark as processed (simulating what the method does) + parameters.put("context_management_processed", "true"); + + // Second check - should prevent processing + boolean shouldProcess2 = !"true".equals(parameters.get("context_management_processed")); + assertFalse("Second call should prevent processing", shouldProcess2); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java index 1e0661c80b..f196da28d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/contextmanagement/ContextManagerFactoryTests.java @@ -140,4 +140,28 @@ public void testCreateContextManager_EmptyType() { assertTrue(e.getMessage().contains("Unsupported context manager type")); } } + + @Test + public void testContextManagerHookProvider_SelectiveRegistration() { + // Test that ContextManagerHookProvider only registers hooks for configured managers + java.util.Map> hookToManagersMap = new java.util.HashMap<>(); + + // Test 1: Only POST_TOOL configured + hookToManagersMap.put("POST_TOOL", java.util.Arrays.asList("ToolsOutputTruncateManager")); + + // Simulate the registration logic + java.util.Set registeredHooks = new java.util.HashSet<>(); + if (hookToManagersMap.containsKey("PRE_LLM")) { + registeredHooks.add("PRE_LLM"); + } + if (hookToManagersMap.containsKey("POST_TOOL")) { + registeredHooks.add("POST_TOOL"); + } + if (hookToManagersMap.containsKey("POST_MEMORY")) { + registeredHooks.add("POST_MEMORY"); + } + + // Assert only POST_TOOL is registered + assertTrue("Should only register POST_TOOL hook", registeredHooks.size() == 1 && registeredHooks.contains("POST_TOOL")); + } } From 51d5b470ef1d723a0d7c23c9be6385ab66ca2d53 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 5 Nov 2025 19:21:03 -0800 Subject: [PATCH 24/58] execute PER agent with multi-tenancy Signed-off-by: Dhrubo Saha --- .../memory/MLUpdateMemoryRequest.java | 7 +++- .../engine/algorithms/agent/AgentUtils.java | 2 + .../MLPlanExecuteAndReflectAgentRunner.java | 9 +++-- .../ml/engine/indices/MLIndicesHandler.java | 2 +- .../memory/AgenticConversationMemory.java | 40 ++++++++++++------- .../memory/AgenticConversationMemoryTest.java | 9 +++-- .../memory/TransportUpdateMemoryAction.java | 3 +- 
.../ml/helper/MemoryContainerHelper.java | 4 +- .../ml/rest/RestMLUpdateMemoryAction.java | 3 ++ 9 files changed, 53 insertions(+), 26 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java index 68f0690a47..73ba3779fb 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/memorycontainer/memory/MLUpdateMemoryRequest.java @@ -31,18 +31,21 @@ public class MLUpdateMemoryRequest extends ActionRequest { private String memoryContainerId; private MemoryType memoryType; private String memoryId; + private String tenantId; @Builder public MLUpdateMemoryRequest( String memoryContainerId, MemoryType memoryType, String memoryId, - MLUpdateMemoryInput mlUpdateMemoryInput + MLUpdateMemoryInput mlUpdateMemoryInput, + String tenantId ) { this.memoryContainerId = memoryContainerId; this.memoryType = memoryType; this.memoryId = memoryId; this.mlUpdateMemoryInput = mlUpdateMemoryInput; + this.tenantId = tenantId; } public MLUpdateMemoryRequest(StreamInput in) throws IOException { @@ -51,6 +54,7 @@ public MLUpdateMemoryRequest(StreamInput in) throws IOException { this.memoryType = in.readEnum(MemoryType.class); this.memoryId = in.readString(); this.mlUpdateMemoryInput = new MLUpdateMemoryInput(in); + this.tenantId = in.readOptionalString(); } @Override @@ -60,6 +64,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnum(memoryType); out.writeString(memoryId); mlUpdateMemoryInput.writeTo(out); + out.writeOptionalString(tenantId); } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index cbc003388f..97d45a4129 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -30,6 +30,7 @@ import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES; import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD; +import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.TENANT_ID_FIELD; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS; @@ -1031,6 +1032,7 @@ public static Map createMemoryParams( memoryParams.put(APP_TYPE, appType); if (mlAgent.getMemory().getMemoryContainerId() != null) { memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId()); + memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); } memoryParams.putIfAbsent(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); return memoryParams; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 9362448300..776d85884c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java 
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -359,7 +359,7 @@ private void setToolsAndRunAgent( AtomicInteger traceNumber = new AtomicInteger(0); - executePlanningLoop(mlAgent.getLlm(), allParams, completedSteps, memory, conversationId, 0, traceNumber, finalListener); + executePlanningLoop(mlAgent.getLlm(), allParams, completedSteps, memory, conversationId, 0, traceNumber, mlAgent.getTenantId(), finalListener); }; // Fetch MCP tools and handle both success and failure cases @@ -380,6 +380,7 @@ private void executePlanningLoop( String conversationId, int stepsExecuted, AtomicInteger traceNumber, + String tenantId, ActionListener finalListener ) { int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); @@ -398,7 +399,7 @@ private void executePlanningLoop( completedSteps.getLast() ); saveAndReturnFinalResult( - (ConversationIndexMemory) memory, + memory, parentInteractionId, allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), @@ -450,7 +451,7 @@ private void executePlanningLoop( .inputDataset(RemoteInferenceInputDataSet.builder().parameters(allParams).build()) .build(), null, - allParams.get(TENANT_ID_FIELD) + tenantId ); StepListener planListener = new StepListener<>(); @@ -496,6 +497,7 @@ private void executePlanningLoop( .agentId(reActAgentId) .functionName(FunctionName.AGENT) .inputDataset(RemoteInferenceInputDataSet.builder().parameters(reactParams).build()) + .tenantId(tenantId) .build(); // Pass hookRegistry to internal agent execution @@ -589,6 +591,7 @@ private void executePlanningLoop( conversationId, stepsExecuted + 1, traceNumber, + tenantId, finalListener ); }, e -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 2b7e428ab6..4eb567802c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -97,7 +97,7 @@ public class MLIndicesHandler { * pre-populated. If this is incorrect, it will result in unwanted early returns without checking the clusterService. 
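To make the revised check in the hunk that follows concrete, a small illustration (the index names are hypothetical; the behavior follows the one-line change below):

    // With multi-tenancy enabled, only dot-prefixed system indices skip the cluster-state lookup;
    // regular indices still fall through to clusterService.state().metadata().hasIndex(...).
    boolean systemIndex = doesMultiTenantIndexExist(clusterService, true, ".plugins-ml-memory"); // true without hasIndex(...)
    boolean userIndex = doesMultiTenantIndexExist(clusterService, true, "my-memory-index");      // delegates to hasIndex(...)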
*/ public static boolean doesMultiTenantIndexExist(ClusterService clusterService, boolean isMultiTenancyEnabled, String indexName) { - return isMultiTenancyEnabled || clusterService.state().metadata().hasIndex(indexName); + return (indexName.startsWith(".") && isMultiTenancyEnabled) || clusterService.state().metadata().hasIndex(indexName); } public void initModelGroupIndexIfAbsent(ActionListener listener) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java index 3b27275afe..7b8361f184 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.memory; +import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.TENANT_ID_FIELD; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; @@ -64,11 +65,13 @@ public class AgenticConversationMemory implements Memory updateContent, ActionLi } // Use retry mechanism for AOSS compatibility (high refresh latency) - updateWithRetry(messageId, updateContent, updateListener, 0); + updateWithRetry(messageId, updateContent, updateListener, 0, tenantId); } /** @@ -194,7 +198,8 @@ private void updateWithRetry( String messageId, Map updateContent, ActionListener updateListener, - int attemptNumber + int attemptNumber, + String tenantId ) { final int maxRetries = 5; final long baseDelayMs = 500; @@ -205,6 +210,7 @@ private void updateWithRetry( .memoryContainerId(memoryContainerId) .memoryType(MemoryType.WORKING) .memoryId(messageId) + .tenantId(tenantId) .build(); client.execute(MLGetMemoryAction.INSTANCE, getRequest, ActionListener.wrap(getResponse -> { @@ -245,6 +251,7 @@ private void updateWithRetry( .memoryType(MemoryType.WORKING) .memoryId(messageId) .mlUpdateMemoryInput(input) + .tenantId(tenantId) .build(); // Step 5: Execute the update @@ -291,7 +298,7 @@ private void updateWithRetry( } // Retry - updateWithRetry(messageId, updateContent, updateListener, attemptNumber + 1); + updateWithRetry(messageId, updateContent, updateListener, attemptNumber + 1, tenantId); } else { if (attemptNumber >= maxRetries) { log.error("Failed to get existing memory after {} retries. 
MessageId: {}", maxRetries, messageId, e); @@ -328,7 +335,7 @@ public void getMessages(int size, ActionListener> listener) { .searchSourceBuilder(searchSourceBuilder) .build(); - MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(tenantId).build(); client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { List interactions = parseSearchResponseToInteractions(searchResponse); @@ -456,7 +463,7 @@ public void getTraces(String parentMessageId, ActionListener> .searchSourceBuilder(searchSourceBuilder) .build(); - MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(null).build(); + MLSearchMemoriesRequest request = MLSearchMemoriesRequest.builder().mlSearchMemoriesInput(searchInput).tenantId(tenantId).build(); client.execute(MLSearchMemoriesAction.INSTANCE, request, ActionListener.wrap(searchResponse -> { List traces = parseSearchResponseToTraces(searchResponse); @@ -554,8 +561,9 @@ public void create(Map map, ActionListener listener ) { // Memory container ID is required for AgenticConversationMemory @@ -579,8 +588,8 @@ public void create( if (Strings.isEmpty(memoryId)) { // Create new session using TransportCreateSessionAction - createSessionInMemoryContainer(name, memoryContainerId, ActionListener.wrap(sessionId -> { - create(sessionId, memoryContainerId, listener); + createSessionInMemoryContainer(name, memoryContainerId, tenantId, ActionListener.wrap(sessionId -> { + create(sessionId, memoryContainerId, tenantId, listener); log.debug("Created session in memory container, session id: {}", sessionId); }, e -> { log.error("Failed to create session in memory container", e); @@ -588,15 +597,18 @@ public void create( })); } else { // Use existing session/memory ID - create(memoryId, memoryContainerId, listener); + create(memoryId, memoryContainerId, tenantId, listener); } } /** * Create a new session in the memory container using the new session API */ - private void createSessionInMemoryContainer(String summary, String memoryContainerId, ActionListener listener) { - MLCreateSessionInput input = MLCreateSessionInput.builder().memoryContainerId(memoryContainerId).summary(summary).build(); + private void createSessionInMemoryContainer(String summary, String memoryContainerId, String tenantId, ActionListener listener) { + MLCreateSessionInput input = MLCreateSessionInput.builder(). + memoryContainerId(memoryContainerId). + tenantId(tenantId). 
+ summary(summary).build(); MLCreateSessionRequest request = MLCreateSessionRequest.builder().mlCreateSessionInput(input).build(); @@ -611,8 +623,8 @@ private void createSessionInMemoryContainer(String summary, String memoryContain ); } - public void create(String memoryId, String memoryContainerId, ActionListener listener) { - listener.onResponse(new AgenticConversationMemory(client, memoryId, memoryContainerId)); + public void create(String memoryId, String memoryContainerId, String tenantId, ActionListener listener) { + listener.onResponse(new AgenticConversationMemory(client, memoryId, memoryContainerId, tenantId)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java index 5d84d01f4b..bb1f86a70b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java @@ -41,7 +41,7 @@ public class AgenticConversationMemoryTest { public void setUp() { MockitoAnnotations.openMocks(this); - agenticMemory = new AgenticConversationMemory(client, "test_conversation_id", "test_memory_container_id"); + agenticMemory = new AgenticConversationMemory(client, "test_conversation_id", "test_memory_container_id", null); } @Test @@ -81,7 +81,7 @@ public void testSaveMessage() { @Test public void testFactoryCreate() { AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); - factory.init(client, mlIndicesHandler, memoryManager); + factory.init(client); Map params = new HashMap<>(); params.put("memory_id", "test_memory_id"); @@ -99,7 +99,7 @@ public void testFactoryCreate() { @Test public void testFactoryCreateWithNewSession() { AgenticConversationMemory.Factory factory = new AgenticConversationMemory.Factory(); - factory.init(client, mlIndicesHandler, memoryManager); + factory.init(client); // Mock session creation doAnswer(invocation -> { @@ -129,7 +129,8 @@ public void testSaveWithoutMemoryContainerId() { AgenticConversationMemory memoryWithoutContainer = new AgenticConversationMemory( client, "test_conversation_id", - null // No memory container ID = should fail + null, // No memory container ID = should fail, + null ); ConversationIndexMessage message = ConversationIndexMessage diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java index 4fe38be24f..d7e7fadcc3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java @@ -81,9 +81,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + memoryContainerHelper.getMemoryContainer(memoryContainerId, tenantId, ActionListener.wrap(container -> { // Validate access permissions User user = RestActionUtils.getUserContext(client); if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) { diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java index cd9a7064f1..c7cebc3d90 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java +++ 
b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerHelper.java @@ -765,7 +765,7 @@ public SearchSourceBuilder addContainerIdFilter(String containerId, SearchSource return searchSourceBuilder; } BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - boolQueryBuilder.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD + ".keyword", containerId)); + boolQueryBuilder.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD, containerId)); return applyFilterToSearchSource(searchSourceBuilder, boolQueryBuilder); } @@ -786,7 +786,7 @@ public QueryBuilder addContainerIdFilter(String containerId, QueryBuilder query) BoolQueryBuilder filteredQuery = QueryBuilders.boolQuery(); filteredQuery.must(query); - filteredQuery.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD + ".keyword", containerId)); + filteredQuery.filter(QueryBuilders.termQuery(MEMORY_CONTAINER_ID_FIELD, containerId)); return filteredQuery; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateMemoryAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateMemoryAction.java index ff755f2119..b58dfd0636 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateMemoryAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateMemoryAction.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.UPDATE_MEMORY_PATH; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE; import static org.opensearch.ml.utils.RestActionUtils.getParameterId; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; import java.io.IOException; import java.util.List; @@ -84,6 +85,7 @@ MLUpdateMemoryRequest getRequest(RestRequest request) throws IOException { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); MLUpdateMemoryInput mlUpdateMemoryInput = MLUpdateMemoryInput.parse(parser); + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); return MLUpdateMemoryRequest .builder() @@ -91,6 +93,7 @@ MLUpdateMemoryRequest getRequest(RestRequest request) throws IOException { .memoryType(memoryType) .memoryId(memoryId) .mlUpdateMemoryInput(mlUpdateMemoryInput) + .tenantId(tenantId) .build(); } } From e57a4cd1a232dbfd3a0a12bb06747ecce502677e Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 6 Nov 2025 13:28:19 -0800 Subject: [PATCH 25/58] add tenant id when encrypt --- .../TransportCreateMemoryContainerAction.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 3b04623a12..19ac26f758 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -147,7 +147,7 @@ protected void doExecute(Task task, MLCreateMemoryContainerRequest request, Acti String tenantId = input.getTenantId(); // Validate configuration before creating memory container - validateConfiguration(input.getConfiguration(), ActionListener.wrap(isValid -> { + validateConfiguration(tenantId, input.getConfiguration(), ActionListener.wrap(isValid -> { // Check if memory container index 
exists, create if not ActionListener indexCheckListener = ActionListener.wrap(created -> { try { @@ -375,7 +375,7 @@ private void indexMemoryContainer(MLMemoryContainer container, ActionListener listener) { + private void validateConfiguration(String tenantId, MemoryConfiguration config, ActionListener listener) { if (config.getRemoteStore() != null && config.getRemoteStore().getConnector() != null) { if (config.getRemoteStore().getEmbeddingModel() != null && (config.getRemoteStore().getIngestPipeline() == null || config.getRemoteStore().getIngestPipeline().isEmpty())) { @@ -414,7 +414,7 @@ private void validateConfiguration(MemoryConfiguration config, ActionListener { + createInternalConnectorForRemoteStore(tenantId, config.getRemoteStore(), ActionListener.wrap(connector -> { // Set the connector ID in the remote store config config.getRemoteStore().setConnector(connector); @@ -619,7 +619,7 @@ private void createConnectorForRemoteStore(RemoteStore remoteStore, ActionListen } } - private void createInternalConnectorForRemoteStore(RemoteStore remoteStore, ActionListener listener) { + private void createInternalConnectorForRemoteStore(String tenantId, RemoteStore remoteStore, ActionListener listener) { try { String connectorName = "auto_" + remoteStore.getType().name().toLowerCase() @@ -652,7 +652,7 @@ private void createInternalConnectorForRemoteStore(RemoteStore remoteStore, Acti connectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); Connector connector = Connector.createConnector(builder, connectorInput.getProtocol()); connector.validateConnectorURL(trustedConnectorEndpointsRegex); - connector.encrypt(mlEngine::encrypt, null); + connector.encrypt(mlEngine::encrypt, tenantId); listener.onResponse(connector); } catch (Exception e) { log.error("Error building connector for remote store", e); From 68361cbd04eaac6821fd65c4194e7aee2c3334b3 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 6 Nov 2025 13:29:55 -0800 Subject: [PATCH 26/58] add tenant id to connector --- .../TransportCreateMemoryContainerAction.java | 1 + .../org/opensearch/ml/helper/RemoteMemoryStoreHelper.java | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 19ac26f758..3877b08f69 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -646,6 +646,7 @@ private void createInternalConnectorForRemoteStore(String tenantId, RemoteStore .parameters(parameters) .credential(credential) .actions(actions) + .tenantId(tenantId) .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java index 0e199708c9..9e5c973135 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java @@ -422,8 +422,11 @@ private void runConnector(Connector connector, String actionName, MLInput mlInpu if (connector == null) { throw new IllegalArgumentException("connector is null"); } + //TODO: currently we only support an internal connector inside the memory container in OASIS.
The tenant id is the same as the container's. + // We should check the tenant id in the future if we use a standalone connector inside the memory container. + String connectorTenantId = connector.getTenantId(); // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. - connector.decrypt(actionName, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + connector.decrypt(actionName, (credential, tenantId) -> encryptor.decrypt(credential, tenantId), connectorTenantId); RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); connectorExecutor.setScriptService(scriptService); From ffc38b0d3f664d17d8f17c1e087cedeb2deb3481 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 6 Nov 2025 13:34:59 -0800 Subject: [PATCH 27/58] add tenant id for get model method Signed-off-by: Yaliang Wu --- .../TransportCreateMemoryContainerAction.java | 15 ++++++++------- .../TransportUpdateMemoryContainerAction.java | 4 +++- .../ml/helper/MemoryContainerModelValidator.java | 8 +++++--- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java index 3877b08f69..c257982762 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerAction.java @@ -393,7 +393,7 @@ private void validateConfiguration(String tenantId, MemoryConfiguration config, config.getRemoteStore().setEmbeddingDimension(embModel.getDimension()); log.info("Auto-created embedding model with ID: {} in remote store", modelId); // Continue with normal validation - validateConfigurationInternal(config, listener); + validateConfigurationInternal(tenantId, config, listener); }, listener::onFailure) ); } else { @@ -405,7 +405,7 @@ private void validateConfiguration(String tenantId, MemoryConfiguration config, ); } // Continue with normal validation - validateConfigurationInternal(config, listener); + validateConfigurationInternal(tenantId, config, listener); } return; } @@ -436,7 +436,7 @@ private void validateConfiguration(String tenantId, MemoryConfiguration config, config.getRemoteStore().setEmbeddingDimension(embModel.getDimension()); log.info("Auto-created embedding model with ID: {} in remote store", modelId); // Continue with normal validation - validateConfigurationInternal(config, listener); + validateConfigurationInternal(tenantId, config, listener); }, listener::onFailure) ); } else { @@ -448,16 +448,16 @@ private void validateConfiguration(String tenantId, MemoryConfiguration config, ); } // Continue with normal validation - validateConfigurationInternal(config, listener); + validateConfigurationInternal(tenantId, config, listener); } }, listener::onFailure)); } else { // Normal validation flow - validateConfigurationInternal(config, listener); + validateConfigurationInternal(tenantId, config, listener); } } - private void validateConfigurationInternal(MemoryConfiguration config, ActionListener listener) { + private void validateConfigurationInternal(String tenantId, MemoryConfiguration config, ActionListener listener) { // Validate that strategies have required AI models
try { MemoryConfiguration.validateStrategiesRequireModels(config); @@ -481,10 +481,11 @@ private void validateConfigurationInternal(MemoryConfiguration config, ActionLis } // Validate LLM model using helper - MemoryContainerModelValidator.validateLlmModel(config.getLlmId(), mlModelManager, client, ActionListener.wrap(isValid -> { + MemoryContainerModelValidator.validateLlmModel(tenantId, config.getLlmId(), mlModelManager, client, ActionListener.wrap(isValid -> { // LLM model is valid, now validate embedding model MemoryContainerModelValidator .validateEmbeddingModel( + tenantId, config.getEmbeddingModelId(), config.getEmbeddingModelType(), mlModelManager, diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java index db6d55a412..e05126322e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java @@ -284,11 +284,13 @@ private void validateAndCreateIndices( String memoryContainerId, ActionListener listener ) { + String tenantId = container.getTenantId(); // Validate LLM model using helper - MemoryContainerModelValidator.validateLlmModel(config.getLlmId(), mlModelManager, client, ActionListener.wrap(llmValid -> { + MemoryContainerModelValidator.validateLlmModel(tenantId, config.getLlmId(), mlModelManager, client, ActionListener.wrap(llmValid -> { // LLM validated, now validate embedding model MemoryContainerModelValidator .validateEmbeddingModel( + tenantId, config.getEmbeddingModelId(), config.getEmbeddingModelType(), mlModelManager, diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java index b625d4ed13..ecc557a50e 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java @@ -32,12 +32,13 @@ public final class MemoryContainerModelValidator { /** * Validates that the LLM model exists and is of REMOTE type. * + * @param tenantId the tenant id. This is necessary for multi-tenancy. 
* @param llmId The LLM model ID to validate * @param modelManager The ML model manager * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void validateLlmModel(String llmId, MLModelManager modelManager, Client client, ActionListener listener) { + public static void validateLlmModel(String tenantId, String llmId, MLModelManager modelManager, Client client, ActionListener listener) { if (llmId == null) { listener.onResponse(true); return; @@ -55,7 +56,7 @@ public static void validateLlmModel(String llmId, MLModelManager modelManager, C listener.onFailure(new IllegalArgumentException(String.format(LLM_MODEL_NOT_FOUND_ERROR, llmId))); }), context::restore); - modelManager.getModel(llmId, wrappedListener); + modelManager.getModel(llmId, tenantId, wrappedListener); } } @@ -69,6 +70,7 @@ public static void validateLlmModel(String llmId, MLModelManager modelManager, C * @param listener Action listener that receives true on success, or error on failure */ public static void validateEmbeddingModel( + String tenantId, String embeddingModelId, FunctionName expectedType, MLModelManager modelManager, @@ -99,7 +101,7 @@ public static void validateEmbeddingModel( listener.onFailure(new IllegalArgumentException(String.format(EMBEDDING_MODEL_NOT_FOUND_ERROR, embeddingModelId))); }), context::restore); - modelManager.getModel(embeddingModelId, wrappedListener); + modelManager.getModel(embeddingModelId, tenantId, wrappedListener); } } } From a7ffb02cb40c8a8e8b673cfa11e69e730a1bba80 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Fri, 7 Nov 2025 05:16:06 -0800 Subject: [PATCH 28/58] remote conversation memory with REST call (#4406) --- .../opensearch/ml/common/MLMemoryType.java | 3 +- .../connector/MLExecuteConnectorRequest.java | 7 +- .../engine/algorithms/agent/AgentUtils.java | 58 +- .../algorithms/agent/MLAgentExecutor.java | 58 +- .../algorithms/agent/MLChatAgentRunner.java | 14 +- .../MLConversationalFlowAgentRunner.java | 3 +- .../MLPlanExecuteAndReflectAgentRunner.java | 22 +- .../memory/AgenticConversationMemory.java | 17 +- .../RemoteAgenticConversationMemory.java | 1225 +++++++++++++++++ .../memory/AgenticConversationMemoryTest.java | 2 +- .../ExecuteConnectorTransportAction.java | 47 +- .../TransportUpdateMemoryContainerAction.java | 31 +- .../helper/MemoryContainerModelValidator.java | 8 +- .../ml/helper/RemoteMemoryStoreHelper.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 5 + 15 files changed, 1430 insertions(+), 72 deletions(-) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java index 31939ce1ca..45b82db53d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java +++ b/common/src/main/java/org/opensearch/ml/common/MLMemoryType.java @@ -9,7 +9,8 @@ public enum MLMemoryType { CONVERSATION_INDEX, - AGENTIC_MEMORY; + AGENTIC_MEMORY, + REMOTE_AGENTIC_MEMORY; public static MLMemoryType from(String value) { if (value != null) { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java index 9b24115455..7e9dac8857 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java +++ 
b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLExecuteConnectorRequest.java @@ -67,7 +67,9 @@ public ActionRequestValidationException validate() { } else if (this.mlInput.getInputDataset() == null) { exception = addValidationError("input data can't be null", exception); } - + if (this.connectorId == null) { + exception = addValidationError("connectorId can't be null", exception); + } return exception; } @@ -82,8 +84,7 @@ public static MLExecuteConnectorRequest fromActionRequest(ActionRequest actionRe return new MLExecuteConnectorRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionRequest into MLPredictionTaskRequest", e); + throw new UncheckedIOException("failed to parse ActionRequest into MLExecuteConnectorRequest", e); } - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 97d45a4129..4b3cf23511 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -10,7 +10,9 @@ import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD; import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_CONTAINER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; +import static org.opensearch.ml.common.memorycontainer.RemoteStore.CREDENTIAL_FIELD; +import static org.opensearch.ml.common.memorycontainer.RemoteStore.ENDPOINT_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.isJson; @@ -65,6 +67,7 @@ import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesReference; @@ -73,8 +76,10 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.agent.MLMemorySpec; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.connector.McpConnector; import org.opensearch.ml.common.connector.McpStreamableHttpConnector; import org.opensearch.ml.common.output.model.ModelTensor; @@ -1024,17 +1029,60 @@ public static Map createMemoryParams( String memoryId, String appType, MLAgent mlAgent, - String memoryContainerId + Map requestParameters ) { Map memoryParams = new HashMap<>(); memoryParams.put(ConversationIndexMemory.MEMORY_NAME, question); memoryParams.put(ConversationIndexMemory.MEMORY_ID, memoryId); memoryParams.put(APP_TYPE, appType); - if (mlAgent.getMemory().getMemoryContainerId() != null) { - memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId()); + MLMemorySpec agentMemory = mlAgent != null ? 
mlAgent.getMemory() : null; + if (agentMemory != null) { + String containerId = agentMemory.getMemoryContainerId(); + if (!Strings.isNullOrEmpty(containerId)) { + memoryParams.put(MEMORY_CONTAINER_ID_FIELD, containerId); + } + memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); + } + if (requestParameters != null) { + String endpointParam = requestParameters.get(ENDPOINT_FIELD); + if (!Strings.isNullOrEmpty(endpointParam)) { + memoryParams.put(ENDPOINT_FIELD, endpointParam); + } + String regionParam = requestParameters.get(HttpConnector.REGION_FIELD); + if (!Strings.isNullOrEmpty(regionParam)) { + memoryParams.put(HttpConnector.REGION_FIELD, regionParam); + } + Map credential = parseStringMapParameter(requestParameters.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); + if (credential != null && !credential.isEmpty()) { + memoryParams.put(CREDENTIAL_FIELD, credential); + } + // Extract user_id if provided + String userIdParam = requestParameters.get("user_id"); + if (!Strings.isNullOrEmpty(userIdParam)) { + memoryParams.put("user_id", userIdParam); + } memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); } - memoryParams.putIfAbsent(MEMORY_CONTAINER_ID_FIELD, memoryContainerId); return memoryParams; } + + private static Map parseStringMapParameter(String rawValue, String fieldName) { + if (Strings.isNullOrEmpty(rawValue)) { + return null; + } + try ( + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rawValue) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Map parsed = parser.mapStrings(); + if (parsed == null || parsed.isEmpty()) { + return null; + } + return parsed; + } catch (IOException ex) { + log.warn("Failed to parse {} field; ignoring value", fieldName, ex); + return null; + } + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 87b26b2957..54f8c3d8bf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -235,6 +235,38 @@ public void execute(Input input, ActionListener listener, TransportChann RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentMLInput .getInputDataset(); MLMemorySpec memorySpec = mlAgent.getMemory(); + Map requestParameters = inputDataSet.getParameters(); + String containerOverride = null; + if (requestParameters != null && requestParameters.containsKey(MEMORY_CONTAINER_ID_FIELD)) { + String containerParam = requestParameters.get(MEMORY_CONTAINER_ID_FIELD); + if (!Strings.isNullOrEmpty(containerParam)) { + containerOverride = containerParam; + } + } + if (containerOverride != null) { + if (memorySpec == null) { + throw new IllegalArgumentException( + "memory_container_id override requires the agent to be configured with memory" + ); + } + String currentContainerId = memorySpec.getMemoryContainerId(); + if (!containerOverride.equals(currentContainerId)) { + MLMemorySpec updatedSpec = memorySpec + .toBuilder() + .memoryContainerId(containerOverride) + .build(); + mlAgent = mlAgent.toBuilder().memory(updatedSpec).build(); + memorySpec = updatedSpec; + log + .debug( + "Agent {} overriding memory container from {} to {}", + agentId, + currentContainerId, + containerOverride + ); + } + } + final MLAgent 
finalMlAgent = mlAgent; String memoryId = inputDataSet.getParameters().get(MEMORY_ID); String parentInteractionId = inputDataSet.getParameters().get(PARENT_INTERACTION_ID); String regenerateInteractionId = inputDataSet.getParameters().get(REGENERATE_INTERACTION_ID); @@ -266,16 +298,24 @@ public void execute(Input input, ActionListener listener, TransportChann && memorySpec.getType() != null && memoryFactoryMap.containsKey(MLMemoryType.from(memorySpec.getType()).name()) && (memoryId == null || parentInteractionId == null)) { - Memory.Factory> memoryFactory = memoryFactoryMap - .get(MLMemoryType.from(memorySpec.getType()).name()); - Map memoryParams = createMemoryParams( question, memoryId, appType, mlAgent, - inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) + requestParameters ); + + // Check if inline connector metadata is present to use RemoteAgenticConversationMemory + Memory.Factory> memoryFactory; + if (memoryParams != null && memoryParams.containsKey("endpoint")) { + // Use RemoteAgenticConversationMemory when inline connector metadata is detected + memoryFactory = memoryFactoryMap.get(MLMemoryType.REMOTE_AGENTIC_MEMORY.name()); + log.info("Detected inline connector metadata, using RemoteAgenticConversationMemory"); + } else { + // Use the originally specified memory factory + memoryFactory = memoryFactoryMap.get(MLMemoryType.from(memorySpec.getType()).name()); + } memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { inputDataSet.getParameters().put(MEMORY_ID, memory.getId()); // get question for regenerate @@ -297,7 +337,7 @@ public void execute(Input input, ActionListener listener, TransportChann isAsync, outputs, modelTensors, - mlAgent, + finalMlAgent, channel, hookRegistry ); @@ -315,7 +355,7 @@ public void execute(Input input, ActionListener listener, TransportChann isAsync, outputs, modelTensors, - mlAgent, + finalMlAgent, channel, hookRegistry ); @@ -340,8 +380,8 @@ public void execute(Input input, ActionListener listener, TransportChann question, memoryId, appType, - mlAgent, - inputDataSet.getParameters().get(MEMORY_CONTAINER_ID_FIELD) + finalMlAgent, + requestParameters ); factory @@ -354,7 +394,7 @@ public void execute(Input input, ActionListener listener, TransportChann mlTask, isAsync, memoryId, - mlAgent, + finalMlAgent, outputs, modelTensors, listener, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index f371e5244c..21bdf26fd5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -7,7 +7,6 @@ import static org.opensearch.ml.common.conversation.ActionConstants.ADDITIONAL_INFO_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -206,9 +205,18 @@ public void run(MLAgent mlAgent, Map inputParams, ActionListener String chatHistoryResponseTemplate = params.get(CHAT_HISTORY_RESPONSE_TEMPLATE); int messageHistoryLimit = getMessageHistoryLimit(params); - Memory.Factory> memoryFactory = 
memoryFactoryMap.get(memoryType); + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params); - Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); + // Check if inline connector metadata is present to use RemoteAgenticConversationMemory + Memory.Factory> memoryFactory; + if (memoryParams != null && memoryParams.containsKey("endpoint")) { + // Use RemoteAgenticConversationMemory when inline connector metadata is detected + memoryFactory = memoryFactoryMap.get(MLMemoryType.REMOTE_AGENTIC_MEMORY.name()); + log.info("Detected inline connector metadata, using RemoteAgenticConversationMemory"); + } else { + // Use the originally specified memory factory + memoryFactory = memoryFactoryMap.get(memoryType); + } memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { // TODO: call runAgent directly if messageHistoryLimit == 0 memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index faeec6b050..f4671f1823 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.conversation.ActionConstants.AI_RESPONSE_FIELD; import static org.opensearch.ml.common.conversation.ActionConstants.MEMORY_ID; import static org.opensearch.ml.common.conversation.ActionConstants.PARENT_INTERACTION_ID_FIELD; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.TOOL_OUTPUT_FILTERS_FIELD; import static org.opensearch.ml.common.utils.ToolUtils.convertOutputToModelTensor; import static org.opensearch.ml.common.utils.ToolUtils.filterToolOutput; @@ -112,7 +111,7 @@ public void run(MLAgent mlAgent, Map params, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); - Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params.get(MEMORY_CONTAINER_ID_FIELD)); + Map memoryParams = createMemoryParams(title, memoryId, appType, mlAgent, params); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(r -> { List messageList = new ArrayList<>(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 776d85884c..f70ddc9f50 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static 
org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly; import static org.opensearch.ml.common.utils.StringUtils.isJson; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; @@ -76,7 +75,6 @@ import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; -import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; @@ -305,13 +303,7 @@ public void run(MLAgent mlAgent, Map apiParams, ActionListener> memoryFactory = memoryFactoryMap.get(memoryType); - Map memoryParams = createMemoryParams( - apiParams.get(USER_PROMPT_FIELD), - memoryId, - appType, - mlAgent, - apiParams.get(MEMORY_CONTAINER_ID_FIELD) - ); + Map memoryParams = createMemoryParams(apiParams.get(USER_PROMPT_FIELD), memoryId, appType, mlAgent, apiParams); memoryFactory.create(memoryParams, ActionListener.wrap(memory -> { memory.getMessages(messageHistoryLimit, ActionListener.>wrap(interactions -> { final List completedSteps = new ArrayList<>(); @@ -359,7 +351,17 @@ private void setToolsAndRunAgent( AtomicInteger traceNumber = new AtomicInteger(0); - executePlanningLoop(mlAgent.getLlm(), allParams, completedSteps, memory, conversationId, 0, traceNumber, mlAgent.getTenantId(), finalListener); + executePlanningLoop( + mlAgent.getLlm(), + allParams, + completedSteps, + memory, + conversationId, + 0, + traceNumber, + mlAgent.getTenantId(), + finalListener + ); }; // Fetch MCP tools and handle both success and failure cases diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java index 7b8361f184..29c65d9979 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/AgenticConversationMemory.java @@ -604,11 +604,18 @@ public void create( /** * Create a new session in the memory container using the new session API */ - private void createSessionInMemoryContainer(String summary, String memoryContainerId, String tenantId, ActionListener listener) { - MLCreateSessionInput input = MLCreateSessionInput.builder(). - memoryContainerId(memoryContainerId). - tenantId(tenantId). 
- summary(summary).build(); + private void createSessionInMemoryContainer( + String summary, + String memoryContainerId, + String tenantId, + ActionListener listener + ) { + MLCreateSessionInput input = MLCreateSessionInput + .builder() + .memoryContainerId(memoryContainerId) + .tenantId(tenantId) + .summary(summary) + .build(); MLCreateSessionRequest request = MLCreateSessionRequest.builder().mlCreateSessionInput(input).build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java new file mode 100644 index 0000000000..c0c74b0d23 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java @@ -0,0 +1,1225 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.memory; + +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_ID; +import static org.opensearch.ml.engine.memory.ConversationIndexMemory.MEMORY_NAME; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLMemoryType; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.memory.Memory; +import org.opensearch.ml.common.memory.Message; +import org.opensearch.ml.common.memorycontainer.MLWorkingMemory; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse; +import org.opensearch.ml.common.transport.memorycontainer.memory.MLGetMemoryResponse; +import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryResult; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.script.ScriptService; +import org.opensearch.search.SearchHit; +import 
org.opensearch.transport.client.Client; + +import com.google.gson.Gson; + +import lombok.Getter; +import lombok.extern.log4j.Log4j2; + +/** + * Remote agentic memory implementation backed by connector-defined REST APIs. + */ +@Log4j2 +@Getter +public class RemoteAgenticConversationMemory implements Memory { + + public static final String TYPE = MLMemoryType.REMOTE_AGENTIC_MEMORY.name(); + private static final String SESSION_ID_FIELD = "session_id"; + private static final String CREATED_TIME_FIELD = "created_time"; + private static final Gson GSON = new Gson(); + + private final String conversationId; + private final String memoryContainerId; + private final String userId; + private final Connector connector; + private final RemoteConnectorExecutor executor; + + // Dependencies for connector execution + private final ScriptService scriptService; + private final ClusterService clusterService; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + + public RemoteAgenticConversationMemory( + String memoryId, + String memoryContainerId, + String userId, + Connector connector, + ScriptService scriptService, + ClusterService clusterService, + Client client, + NamedXContentRegistry xContentRegistry + ) { + this.conversationId = memoryId; + this.memoryContainerId = memoryContainerId; + this.userId = userId; + this.connector = connector; + this.scriptService = scriptService; + this.clusterService = clusterService; + this.client = client; + this.xContentRegistry = xContentRegistry; + + // Initialize the executor for the connector + this.executor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + this.executor.setScriptService(scriptService); + this.executor.setClusterService(clusterService); + this.executor.setClient(client); + this.executor.setXContentRegistry(xContentRegistry); + + // Log creation for debugging/monitoring + log + .info( + "RemoteAgenticConversationMemory created - sessionId: {}, containerId: {}, endpoint: {}, protocol: {}", + memoryId, + memoryContainerId, + connector.getParameters() != null ? connector.getParameters().get("endpoint") : "unknown", + connector.getProtocol() + ); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getId() { + return conversationId; + } + + @Override + public void save(Message message, String parentId, Integer traceNum, String action) { + this.save(message, parentId, traceNum, action, ActionListener.wrap(r -> { + log.info("Saved message to remote agentic memory, session id: {}, working memory id: {}", conversationId, r.getId()); + }, e -> { log.error("Failed to save message to remote agentic memory", e); })); + } + + @Override + public void save( + Message message, + String parentId, + Integer traceNum, + String action, + ActionListener listener + ) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalStateException( + "Memory container ID is not configured for this RemoteAgenticConversationMemory. " + + "Cannot save messages without a valid memory container." 
+ ) + ); + return; + } + + ConversationIndexMessage msg = (ConversationIndexMessage) message; + + // Build namespace with session_id and optionally user_id + Map namespace = new HashMap<>(); + namespace.put(SESSION_ID_FIELD, conversationId); + if (!Strings.isNullOrEmpty(userId)) { + namespace.put("user_id", userId); + } + + // Simple rule matching ConversationIndexMemory: + // - If traceNum != null → it's a trace + // - If traceNum == null → it's a message + boolean isTrace = (traceNum != null); + + Map metadata = new HashMap<>(); + Map structuredData = new HashMap<>(); + + // Store data in structured_data format matching conversation index + structuredData.put("input", msg.getQuestion() != null ? msg.getQuestion() : ""); + structuredData.put("response", msg.getResponse() != null ? msg.getResponse() : ""); + + if (isTrace) { + // This is a trace (tool usage or intermediate step) + metadata.put("type", "trace"); + if (parentId != null) { + metadata.put("parent_message_id", parentId); + structuredData.put("parent_message_id", parentId); + } + metadata.put("trace_number", String.valueOf(traceNum)); + structuredData.put("trace_number", traceNum); + if (action != null) { + metadata.put("origin", action); + structuredData.put("origin", action); + } + } else { + // This is a final message (Q&A pair) + metadata.put("type", "message"); + if (msg.getFinalAnswer() != null) { + structuredData.put("final_answer", msg.getFinalAnswer()); + } + } + + // Add timestamps + java.time.Instant now = java.time.Instant.now(); + structuredData.put("create_time", now.toString()); + structuredData.put("updated_time", now.toString()); + + // Build request body for add_memory action + Map requestBody = new HashMap<>(); + requestBody.put("memory_container_id", memoryContainerId); + requestBody.put("structured_data", structuredData); + requestBody.put("message_id", traceNum); // Store trace number in messageId field (null for messages) + requestBody.put("namespace", namespace); + requestBody.put("metadata", metadata); + requestBody.put("infer", false); // Don't infer long-term memory by default + + // Execute the connector action + executeConnectorAction("add_memory", requestBody, ActionListener.wrap(response -> { + // Parse response using proper Response class + MLAddMemoriesResponse addResponse = parseAddMemoryResponse(response); + String workingMemoryId = addResponse.getWorkingMemoryId(); + CreateInteractionResponse interactionResponse = new CreateInteractionResponse(workingMemoryId); + listener.onResponse(interactionResponse); + }, e -> { + log.error("Failed to add memories to remote memory container", e); + listener.onFailure(e); + })); + } + + @Override + public void update(String messageId, Map updateContent, ActionListener updateListener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + updateListener + .onFailure(new IllegalStateException("Memory container ID is not configured for this RemoteAgenticConversationMemory")); + return; + } + + // Use retry mechanism for AOSS compatibility (high refresh latency) + updateWithRetry(messageId, updateContent, updateListener, 0); + } + + /** + * Update with retry mechanism to handle AOSS refresh latency (up to 10s) + * Uses exponential backoff: 500ms, 1s, 2s, 4s, 8s + */ + private void updateWithRetry( + String messageId, + Map updateContent, + ActionListener updateListener, + int attemptNumber + ) { + final int maxRetries = 5; + final long baseDelayMs = 500; + + // Step 1: Get the existing working memory to retrieve current structured_data + Map getRequest = new 
HashMap<>(); + getRequest.put("memory_container_id", memoryContainerId); + getRequest.put("memory_type", "working"); + getRequest.put("memory_id", messageId); + + executeConnectorAction("get_memory", getRequest, ActionListener.wrap(getResponse -> { + // Step 2: Extract existing structured_data using proper Response class + MLGetMemoryResponse memoryResponse = parseGetMemoryResponse(getResponse); + MLWorkingMemory workingMemory = memoryResponse.getWorkingMemory(); + + Map structuredData; + if (workingMemory == null || workingMemory.getStructuredData() == null) { + structuredData = new HashMap<>(); + } else { + // Create a mutable copy + structuredData = new HashMap<>(workingMemory.getStructuredData()); + } + + // Step 3: Merge update content into structured_data + for (Map.Entry entry : updateContent.entrySet()) { + structuredData.put(entry.getKey(), entry.getValue()); + } + + // Step 4: Create update request with merged structured_data + Map finalUpdateContent = new HashMap<>(); + finalUpdateContent.put("structured_data", structuredData); + + Map updateRequest = new HashMap<>(); + updateRequest.put("memory_container_id", memoryContainerId); + updateRequest.put("memory_type", "working"); + updateRequest.put("memory_id", messageId); + updateRequest.put("update_content", finalUpdateContent); + + // Step 5: Execute the update + executeConnectorAction("update_memory", updateRequest, ActionListener.wrap(response -> { + try { + // Parse using standard UpdateResponse parser + UpdateResponse updateResponse = parseUpdateResponse(response); + updateListener.onResponse(updateResponse); + } catch (Exception parseException) { + log.error("Failed to parse update response from remote memory", parseException); + updateListener.onFailure(parseException); + } + }, e -> { + log.error("Failed to update memory in remote memory container", e); + updateListener.onFailure(e); + })); + }, e -> { + // Check if it's a 404 (document not found) and we haven't exceeded max retries + boolean isNotFound = e.getMessage() != null && (e.getMessage().contains("404") || e.getMessage().contains("\"found\":false")); + + if (isNotFound && attemptNumber < maxRetries) { + // Calculate delay with exponential backoff + long delayMs = baseDelayMs * (1L << attemptNumber); + + log + .warn( + "Document not found (attempt {}/{}), retrying after {}ms due to refresh latency. MessageId: {}", + attemptNumber + 1, + maxRetries, + delayMs, + messageId + ); + + // Schedule retry after delay + try { + Thread.sleep(delayMs); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + updateListener.onFailure(new RuntimeException("Retry interrupted", ie)); + return; + } + + // Retry + updateWithRetry(messageId, updateContent, updateListener, attemptNumber + 1); + } else { + if (attemptNumber >= maxRetries) { + log.error("Failed to get existing memory after {} retries. MessageId: {}", maxRetries, messageId, e); + } else { + log.error("Failed to get existing memory for update. 
MessageId: {}", messageId, e); + } + updateListener.onFailure(e); + } + })); + } + + @Override + public void getMessages(int size, ActionListener> listener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this RemoteAgenticConversationMemory")); + return; + } + + // Build search query for working memory by session_id, filtering only final messages (not traces) + Map query = new HashMap<>(); + Map bool = new HashMap<>(); + List> must = new ArrayList<>(); + List> mustNot = new ArrayList<>(); + + // Must match session_id + Map sessionTerm = new HashMap<>(); + sessionTerm.put("namespace." + SESSION_ID_FIELD, conversationId); + must.add(Map.of("term", sessionTerm)); + + // Must not have trace_number (exclude traces) + mustNot.add(Map.of("exists", Map.of("field", "structured_data.trace_number"))); + + bool.put("must", must); + bool.put("must_not", mustNot); + query.put("bool", bool); + + // Build search request + Map searchRequest = new HashMap<>(); + searchRequest.put("memory_container_id", memoryContainerId); + searchRequest.put("memory_type", "working"); + searchRequest.put("query", query); + searchRequest.put("size", size); + searchRequest.put("sort", List.of(Map.of(CREATED_TIME_FIELD, "asc"))); + + executeConnectorAction("search_memories", searchRequest, ActionListener.wrap(response -> { + List interactions = parseSearchResponseToInteractions(response); + listener.onResponse(interactions); + }, e -> { + log.error("Failed to search memories in remote memory container", e); + listener.onFailure(e); + })); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear method is not supported in RemoteAgenticConversationMemory"); + } + + @Override + public void deleteInteractionAndTrace(String interactionId, ActionListener listener) { + // For now, delegate to a simple implementation + // In the future, this could use delete_memory action + log.warn("deleteInteractionAndTrace is not fully implemented for RemoteAgenticConversationMemory"); + listener.onResponse(false); + } + + /** + * Get traces (intermediate steps/tool usage) for a specific parent message + * @param parentMessageId The parent message ID + * @param listener Action listener for the traces + */ + public void getTraces(String parentMessageId, ActionListener> listener) { + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener.onFailure(new IllegalStateException("Memory container ID is not configured for this RemoteAgenticConversationMemory")); + return; + } + + // Build search query for traces by parent_message_id + Map query = new HashMap<>(); + Map bool = new HashMap<>(); + List> must = new ArrayList<>(); + + // Must match session_id + Map sessionTerm = new HashMap<>(); + sessionTerm.put("namespace." 
+ SESSION_ID_FIELD, conversationId); + must.add(Map.of("term", sessionTerm)); + + // Must be a trace + Map typeTerm = new HashMap<>(); + typeTerm.put("metadata.type", "trace"); + must.add(Map.of("term", typeTerm)); + + // Must have specific parent_message_id + Map parentTerm = new HashMap<>(); + parentTerm.put("metadata.parent_message_id", parentMessageId); + must.add(Map.of("term", parentTerm)); + + bool.put("must", must); + query.put("bool", bool); + + // Build search request + Map searchRequest = new HashMap<>(); + searchRequest.put("memory_container_id", memoryContainerId); + searchRequest.put("memory_type", "working"); + searchRequest.put("query", query); + searchRequest.put("size", 1000); // Get all traces for this message + searchRequest.put("sort", List.of(Map.of("message_id", "asc"))); // Sort by trace number + + executeConnectorAction("search_memories", searchRequest, ActionListener.wrap(response -> { + List traces = parseSearchResponseToTraces(response); + listener.onResponse(traces); + }, e -> { + log.error("Failed to search traces in remote memory container", e); + listener.onFailure(e); + })); + } + + /** + * Helper method to execute connector actions + */ + private void executeConnectorAction(String action, Map parameters, ActionListener listener) { + // Log the action being executed for debugging + if (log.isDebugEnabled()) { + Map actionDebug = new HashMap<>(); + actionDebug.put("action", action); + + // Log parameter keys but mask values that might be sensitive + if (parameters != null && !parameters.isEmpty()) { + Map paramKeys = new HashMap<>(); + for (String key : parameters.keySet()) { + Object value = parameters.get(key); + paramKeys.put(key, value); + } + actionDebug.put("parameters", paramKeys); + } + + log.debug("Executing RemoteAgenticConversationMemory action: {}", GSON.toJson(actionDebug)); + } + + // Use the cached executor that was initialized in the constructor + // The executor already has the connector with all actions defined and decrypted credentials + + // Prepare parameters for the action + Map inputParams = new HashMap<>(); + + // Add required parameters that match the URL template placeholders + inputParams.put("memory_container_id", (String) parameters.get("memory_container_id")); + + if (parameters.containsKey("memory_id")) { + inputParams.put("memory_id", (String) parameters.get("memory_id")); + } + if (parameters.containsKey("memory_type")) { + inputParams.put("memory_type", (String) parameters.get("memory_type")); + } + + // Build request body based on action type + String requestBody = buildRequestBody(action, parameters); + if (requestBody != null) { + inputParams.put("body", requestBody); + } + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(inputParams).build(); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + + // Execute the action - the executor will find the action by name in the connector's actions list + this.executor.executeAction(action, mlInput, ActionListener.wrap(response -> { + String output; + + // Try to extract the actual data from ModelTensorOutput wrapper + Map dataMap = extractDataFromModelTensorOutput(response); + + if (dataMap != null) { + // Convert the extracted data map back to JSON + output = GSON.toJson(dataMap); + } else if (response instanceof MLTaskResponse) { + // Fallback: if extraction fails, use toString() as before + MLTaskResponse taskResponse = (MLTaskResponse) response; + output = 
taskResponse.getOutput() != null ? taskResponse.getOutput().toString() : "{}"; + } else { + output = response.toString(); + } + + // Log successful response for debugging (truncate if too long) + if (log.isDebugEnabled()) { + String debugOutput = output; + if (debugOutput.length() > 500) { + debugOutput = debugOutput.substring(0, 500) + "... [truncated]"; + } + log.debug("RemoteAgenticConversationMemory action '{}' response: {}", action, debugOutput); + } + + listener.onResponse(output); + }, e -> { + log.error("Failed to execute connector action '{}' for RemoteAgenticConversationMemory: {}", action, e.getMessage(), e); + listener.onFailure(e); + })); + } + + /** + * Build request body based on action type + */ + private String buildRequestBody(String action, Map parameters) { + switch (action) { + case "add_memory": + return GSON.toJson(parameters); + case "search_memories": + Map searchBody = new HashMap<>(); + searchBody.put("query", parameters.get("query")); + if (parameters.containsKey("size")) { + searchBody.put("size", parameters.get("size")); + } + if (parameters.containsKey("sort")) { + searchBody.put("sort", parameters.get("sort")); + } + return GSON.toJson(searchBody); + case "update_memory": + return GSON.toJson(parameters.get("update_content")); + case "create_session": + Map sessionBody = new HashMap<>(); + if (parameters.containsKey("summary")) { + sessionBody.put("summary", parameters.get("summary")); + } + return GSON.toJson(sessionBody); + default: + return null; + } + } + + /** + * Parse JSON string into SearchResponse using OpenSearch's standard parser + */ + private SearchResponse parseSearchResponse(String jsonResponse) throws IOException { + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, jsonResponse)) { + return SearchResponse.fromXContent(parser); + } + } + + /** + * Parse JSON string into GetResponse using OpenSearch's standard parser + */ + private GetResponse parseGetResponse(String jsonResponse) throws IOException { + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, jsonResponse)) { + return GetResponse.fromXContent(parser); + } + } + + /** + * Parse JSON string into UpdateResponse using OpenSearch's standard parser + */ + private UpdateResponse parseUpdateResponse(String jsonResponse) throws IOException { + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, jsonResponse)) { + return UpdateResponse.fromXContent(parser); + } + } + + /** + * Parse add memory response using XContentParser + */ + private MLAddMemoriesResponse parseAddMemoryResponse(String jsonResponse) { + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, jsonResponse)) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + String workingMemoryId = null; + String sessionId = null; + List results = new ArrayList<>(); + + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case "working_memory_id": + workingMemoryId = parser.text(); + break; + case "session_id": + sessionId = parser.text(); + break; + case "long_term_memories": + // Parse array if needed in future + parser.skipChildren(); + break; + default: + parser.skipChildren(); + } + } + + return 
MLAddMemoriesResponse.builder().workingMemoryId(workingMemoryId).sessionId(sessionId).results(results).build(); + } catch (Exception e) { + log.error("Failed to parse add memory response: " + jsonResponse, e); + // Return a minimal response with null values + return MLAddMemoriesResponse.builder().build(); + } + } + + /** + * Parse get memory response using XContentParser + * Following the pattern from MLGetMemoryResponse.fromGetResponse + */ + private MLGetMemoryResponse parseGetMemoryResponse(String jsonResponse) { + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, jsonResponse)) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + // Parse the entire response as MLWorkingMemory (matching MLGetMemoryResponse.fromGetResponse pattern) + MLWorkingMemory workingMemory = MLWorkingMemory.parse(parser); + + return MLGetMemoryResponse.builder().workingMemory(workingMemory).build(); + } catch (Exception e) { + log.error("Failed to parse get memory response: " + jsonResponse, e); + return MLGetMemoryResponse.builder().build(); + } + } + + private List parseSearchResponseToInteractions(String response) { + try { + SearchResponse searchResponse = parseSearchResponse(response); + return convertSearchHitsToMessages(searchResponse); + } catch (Exception e) { + log.error("Failed to parse search response: " + response, e); + return new ArrayList<>(); + } + } + + /** + * Convert SearchResponse hits to Message list + */ + private List convertSearchHitsToMessages(SearchResponse searchResponse) { + List messages = new ArrayList<>(); + if (searchResponse.getHits() != null && searchResponse.getHits().getHits() != null) { + for (SearchHit hit : searchResponse.getHits().getHits()) { + try { + Interaction interaction = convertHitToInteraction(hit, false); + if (interaction != null) { + messages.add(interaction); + } + } catch (Exception e) { + log.warn("Failed to parse hit: " + hit.getId(), e); + } + } + } + return messages; + } + + /** + * Convert a SearchHit to an Interaction object + * @param hit The SearchHit from the SearchResponse + * @param isTrace Whether this is a trace (true) or a message (false) + * @return Interaction object or null if conversion fails + */ + private Interaction convertHitToInteraction(SearchHit hit, boolean isTrace) { + String id = hit.getId(); + Map source = hit.getSourceAsMap(); + + if (source != null && source.containsKey("structured_data")) { + Map structuredData = (Map) source.get("structured_data"); + + String input = (String) structuredData.getOrDefault("input", ""); + String responseText = (String) structuredData.getOrDefault("response", ""); + + // For traces, extract origin from structured_data + String origin = isTrace ? (String) structuredData.getOrDefault("origin", "remote_agentic_memory") : "remote_agentic_memory"; + + // Extract timestamps + Long createdTimeMs = source.containsKey("created_time") ? ((Number) source.get("created_time")).longValue() : null; + Long updatedTimeMs = source.containsKey("last_updated_time") ? ((Number) source.get("last_updated_time")).longValue() : null; + + java.time.Instant createTime = createdTimeMs != null ? java.time.Instant.ofEpochMilli(createdTimeMs) : java.time.Instant.now(); + java.time.Instant updatedTime = updatedTimeMs != null ? 
java.time.Instant.ofEpochMilli(updatedTimeMs) : null; + + // Extract metadata + String parentInteractionId = null; + Integer traceNumber = null; + + if (source.containsKey("metadata")) { + Map metadata = (Map) source.get("metadata"); + parentInteractionId = (String) metadata.get("parent_message_id"); + + // For traces, extract trace number + if (isTrace && metadata.containsKey("trace_number")) { + Object traceNum = metadata.get("trace_number"); + if (traceNum instanceof Number) { + traceNumber = ((Number) traceNum).intValue(); + } + } + } + + // For traces, also check parent_message_id in structured_data + if (isTrace && parentInteractionId == null && structuredData.containsKey("parent_message_id")) { + parentInteractionId = (String) structuredData.get("parent_message_id"); + } + + // For traces, extract trace number from structured_data if not found in metadata + if (isTrace && traceNumber == null) { + if (structuredData.containsKey("trace_number")) { + Object traceNum = structuredData.get("trace_number"); + if (traceNum instanceof Number) { + traceNumber = ((Number) traceNum).intValue(); + } + } + // Also check message_id field at root level for traces + if (traceNumber == null && source.containsKey("message_id")) { + Object msgId = source.get("message_id"); + if (msgId instanceof Number) { + traceNumber = ((Number) msgId).intValue(); + } + } + } + + // Create Interaction object + if (!input.isEmpty() || !responseText.isEmpty()) { + return Interaction + .builder() + .id(id) + .conversationId(conversationId) + .createTime(createTime) + .updatedTime(updatedTime) + .input(input) + .response(responseText) + .origin(origin) + .promptTemplate(null) + .additionalInfo(null) + .parentInteractionId(parentInteractionId) + .traceNum(traceNumber) + .build(); + } + } + + return null; + } + + private List parseSearchResponseToTraces(String response) { + try { + SearchResponse searchResponse = parseSearchResponse(response); + return convertSearchHitsToTraces(searchResponse); + } catch (Exception e) { + log.error("Failed to parse trace response: " + response, e); + return new ArrayList<>(); + } + } + + /** + * Convert SearchResponse hits to Interaction list (for traces) + */ + private List convertSearchHitsToTraces(SearchResponse searchResponse) { + List traces = new ArrayList<>(); + if (searchResponse.getHits() != null && searchResponse.getHits().getHits() != null) { + for (SearchHit hit : searchResponse.getHits().getHits()) { + try { + Interaction trace = convertHitToInteraction(hit, true); + if (trace != null) { + traces.add(trace); + } + } catch (Exception e) { + log.warn("Failed to parse trace hit: " + hit.getId(), e); + } + } + } + return traces; + } + + /** + * Extract data map from ModelTensorOutput response + * This handles the ML Commons output wrapping that converts raw responses into ModelTensorOutput structures + * + * @param response The response object, typically MLTaskResponse + * @return Map containing the actual response data, or null if extraction fails + */ + protected static Map extractDataFromModelTensorOutput(Object response) { + if (response instanceof MLTaskResponse) { + MLTaskResponse taskResponse = (MLTaskResponse) response; + MLOutput mlOutput = taskResponse.getOutput(); + + if (mlOutput instanceof ModelTensorOutput) { + ModelTensorOutput tensorOutput = (ModelTensorOutput) mlOutput; + List outputs = tensorOutput.getMlModelOutputs(); + + if (outputs != null && !outputs.isEmpty()) { + List tensors = outputs.get(0).getMlModelTensors(); + if (tensors != null && !tensors.isEmpty()) 
{ + return tensors.get(0).getDataAsMap(); + } + } + } + } + return null; + } + + /** + * Factory for creating RemoteAgenticConversationMemory instances + */ + public static class Factory implements Memory.Factory { + private ScriptService scriptService; + private ClusterService clusterService; + private Client client; + private NamedXContentRegistry xContentRegistry; + + public void init( + ScriptService scriptService, + ClusterService clusterService, + Client client, + NamedXContentRegistry xContentRegistry + ) { + this.scriptService = scriptService; + this.clusterService = clusterService; + this.client = client; + this.xContentRegistry = xContentRegistry; + } + + @Override + public void create(Map map, ActionListener listener) { + if (map == null || map.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Invalid input parameter for creating RemoteAgenticConversationMemory")); + return; + } + + String memoryId = (String) map.get(MEMORY_ID); + String name = (String) map.get(MEMORY_NAME); + String appType = (String) map.get(APP_TYPE); + String memoryContainerId = (String) map.get("memory_container_id"); + + // Extract inline connector metadata + String endpoint = (String) map.get("endpoint"); + String region = (String) map.get("region"); + Map credential = (Map) map.get("credential"); + String userId = (String) map.get("user_id"); + + create(name, memoryId, appType, memoryContainerId, endpoint, region, credential, userId, listener); + } + + public void create( + String name, + String memoryId, + String appType, + String memoryContainerId, + String endpoint, + String region, + Map credential, + String userId, + ActionListener listener + ) { + // Memory container ID is required + if (Strings.isNullOrEmpty(memoryContainerId)) { + listener + .onFailure( + new IllegalArgumentException( + "Memory container ID is required for RemoteAgenticConversationMemory. " + + "Please provide 'memory_container_id' in the agent configuration." + ) + ); + return; + } + + // Inline connector parameters are required + if (Strings.isNullOrEmpty(endpoint)) { + listener.onFailure(new IllegalArgumentException("Endpoint is required for RemoteAgenticConversationMemory")); + return; + } + + // Create inline connector + Connector connector = createInlineConnector(endpoint, region, credential); + + if (Strings.isEmpty(memoryId)) { + // Create new session using create_session action + createSessionInRemoteContainer(name, memoryContainerId, connector, ActionListener.wrap(sessionId -> { + create(sessionId, memoryContainerId, connector, userId, listener); + log.debug("Created session in remote memory container, session id: {}", sessionId); + }, e -> { + log.error("Failed to create session in remote memory container", e); + listener.onFailure(e); + })); + } else { + // Use existing session/memory ID + create(memoryId, memoryContainerId, connector, userId, listener); + } + } + + /** + * Create a new session in the remote memory container. + * + * This method uses the connector that already has all actions defined, + * including the create_session action with the proper name field. 
+ */ + private void createSessionInRemoteContainer( + String summary, + String memoryContainerId, + Connector connector, + ActionListener listener + ) { + // The connector already has all actions defined, including create_session + // Log connector actions for debugging + if (log.isDebugEnabled()) { + if (connector.getActions() != null) { + log.debug("Connector has {} actions defined", connector.getActions().size()); + for (ConnectorAction action : connector.getActions()) { + log.debug("Action: name='{}', actionType='{}'", action.getName(), action.getActionType()); + } + } else { + log.debug("Connector has no actions defined!"); + } + } + + // Create executor for this connector + RemoteConnectorExecutor executor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + executor.setScriptService(scriptService); + executor.setClusterService(clusterService); + executor.setClient(client); + executor.setXContentRegistry(xContentRegistry); + + // Prepare parameters for the action + Map inputParams = new HashMap<>(); + inputParams.put("memory_container_id", memoryContainerId); + + // Build request body for create_session + Map sessionBody = new HashMap<>(); + if (summary != null) { + sessionBody.put("summary", summary); + } + inputParams.put("body", GSON.toJson(sessionBody)); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(inputParams).build(); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.CONNECTOR).inputDataset(inputDataSet).build(); + + // Execute the action - the executor will find the create_session action by name + executor.executeAction("create_session", mlInput, ActionListener.wrap(response -> { + try { + String sessionId = null; + + // Extract data from ModelTensorOutput wrapper + Map dataMap = extractDataFromModelTensorOutput(response); + + // Log the create session response for debugging + log.debug("Create session response - extracted data: {}", dataMap != null ? GSON.toJson(dataMap) : "null"); + + if (dataMap != null && dataMap.containsKey("session_id")) { + sessionId = (String) dataMap.get("session_id"); + } + + if (sessionId != null) { + listener.onResponse(sessionId); + } else { + listener.onFailure(new RuntimeException("Failed to parse session_id from response")); + } + } catch (Exception e) { + listener.onFailure(e); + } + }, e -> { + log.error("Failed to create session via remote connector", e); + listener.onFailure(e); + })); + } + + public void create( + String memoryId, + String memoryContainerId, + Connector connector, + String userId, + ActionListener listener + ) { + listener + .onResponse( + new RemoteAgenticConversationMemory( + memoryId, + memoryContainerId, + userId, + connector, + scriptService, + clusterService, + client, + xContentRegistry + ) + ); + } + + /** + * Create inline connector from runtime parameters + */ + private Connector createInlineConnector(String endpoint, String region, Map credential) { + // Validate endpoint + if (!isValidEndpoint(endpoint)) { + throw new IllegalArgumentException("Invalid endpoint URL: " + endpoint); + } + + // Determine protocol based on credentials + String protocol = (region != null && credential != null && !credential.isEmpty()) ? 
"aws_sigv4" : "http"; + + // Build parameters + Map parameters = new HashMap<>(); + parameters.put("endpoint", endpoint); + if (region != null) { + parameters.put("region", region); + } + + // For AWS services, we need to specify the service name + // Extract from endpoint or default to "es" for OpenSearch/Elasticsearch + String serviceName = extractServiceName(endpoint); + parameters.put("service_name", serviceName); + + Map credentials = new HashMap<>(); + if (credential != null && !credential.isEmpty()) { + // Pass the credential map directly - it should already contain the correct structure + // For AWS SigV4: access_key, secret_key, session_token (optional) + // For other auth types: appropriate key-value pairs + credentials.putAll(credential); + } + + // Create Memory Container API actions + List actions = createMemoryContainerActions(); + + // Create appropriate connector based on protocol + Connector connector; + if ("aws_sigv4".equals(protocol)) { + // Use AwsConnector for AWS SigV4 + connector = AwsConnector + .awsConnectorBuilder() + .name("inline_remote_memory_connector") + .protocol(protocol) + .parameters(parameters) + .credential(credentials) + .actions(actions) + .build(); + } else { + // Use HttpConnector for plain HTTP + connector = HttpConnector + .builder() + .name("inline_remote_memory_connector") + .protocol(protocol) + .parameters(parameters) + .credential(credentials.isEmpty() ? null : credentials) + .actions(actions) + .build(); + } + + // Log connector configuration for debugging (mask sensitive credentials) + if (log.isDebugEnabled()) { + Map debugInfo = new HashMap<>(); + debugInfo.put("name", connector.getName()); + debugInfo.put("protocol", connector.getProtocol()); + debugInfo.put("parameters", connector.getParameters()); + debugInfo + .put( + "actions", + actions + .stream() + .map(a -> a.getName() != null ? 
a.getName() : a.getActionType().toString()) + .collect(Collectors.toList()) + ); + + // Log credential keys but not values for security + if (credentials != null && !credentials.isEmpty()) { + debugInfo.put("credential_keys", credentials.keySet()); + } + + log.debug("Created inline connector for RemoteAgenticConversationMemory: {}", GSON.toJson(debugInfo)); + } + + // Decrypt the connector credentials (for inline connectors, credentials are already plaintext) + // This populates the decryptedCredential field which AwsConnector methods depend on + connector + .decrypt( + ConnectorAction.ActionType.EXECUTE.name(), + (cred, tenantId) -> cred, // No-op function - credentials are already plaintext + null // No tenant ID for inline connectors + ); + + return connector; + } + + /** + * Create Memory Container API actions for the inline connector + */ + private List createMemoryContainerActions() { + List actions = new ArrayList<>(); + + // Create session action + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("create_session") + .method("POST") + .url("${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories/sessions") + .headers(Map.of("Content-Type", "application/json")) + .requestBody("${parameters.body}") + .build() + ); + + // Add memory action + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("add_memory") + .method("POST") + .url("${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories") + .headers(Map.of("Content-Type", "application/json")) + .requestBody("${parameters.body}") + .build() + ); + + // Search memories action + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("search_memories") + .method("POST") + .url( + "${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories/${parameters.memory_type}/_search" + ) + .headers(Map.of("Content-Type", "application/json")) + .requestBody("${parameters.body}") + .build() + ); + + // Get memory action + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("get_memory") + .method("GET") + .url( + "${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories/${parameters.memory_type}/${parameters.memory_id}" + ) + .headers(Map.of("Content-Type", "application/json")) + .build() + ); + + // Update memory action + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("update_memory") + .method("PUT") + .url( + "${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories/${parameters.memory_type}/${parameters.memory_id}" + ) + .headers(Map.of("Content-Type", "application/json")) + .requestBody("${parameters.body}") + .build() + ); + + // Delete memory action (if needed in the future) + actions + .add( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.EXECUTE) + .name("delete_memory") + .method("DELETE") + .url( + "${parameters.endpoint}/_plugins/_ml/memory_containers/${parameters.memory_container_id}/memories/${parameters.memory_type}/${parameters.memory_id}" + ) + .headers(Map.of("Content-Type", "application/json")) + .build() + ); + + return actions; + } + + /** + * Helper method to extract service name from endpoint + */ + private String extractServiceName(String 
endpoint) { + // For AOSS endpoints: https://xxx.us-west-2.aoss.amazonaws.com + if (endpoint.contains(".aoss.amazonaws.com")) { + return "aoss"; + } + // For managed OpenSearch: https://xxx.us-west-2.es.amazonaws.com + if (endpoint.contains(".es.amazonaws.com")) { + return "es"; + } + // Default to es for OpenSearch/Elasticsearch service + return "es"; + } + + /** + * Validate endpoint URL + */ + private boolean isValidEndpoint(String endpoint) { + try { + // Basic validation - ensure it starts with http:// or https:// + return endpoint != null && (endpoint.startsWith("http://") || endpoint.startsWith("https://")); + } catch (Exception e) { + return false; + } + } + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java index bb1f86a70b..1a9b07fd6b 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/memory/AgenticConversationMemoryTest.java @@ -130,7 +130,7 @@ public void testSaveWithoutMemoryContainerId() { client, "test_conversation_id", null, // No memory container ID = should fail, - null + null ); ConversationIndexMessage message = ConversationIndexMessage diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java index cb5fe126ab..9f514b30e6 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -22,7 +22,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.MLTaskResponse; -import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorAction; import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest; import org.opensearch.ml.engine.MLEngineClassLoader; @@ -61,7 +60,7 @@ public ExecuteConnectorTransportAction( EncryptorImpl encryptor, MLFeatureEnabledSetting mlFeatureEnabledSetting ) { - super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLConnectorDeleteRequest::new); + super(MLExecuteConnectorAction.NAME, transportService, actionFilters, MLExecuteConnectorRequest::new); this.client = client; this.clusterService = clusterService; this.scriptService = scriptService; @@ -74,31 +73,20 @@ public ExecuteConnectorTransportAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request); - String connectorId = executeConnectorRequest.getConnectorId(); RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) executeConnectorRequest.getMlInput().getInputDataset(); String connectorAction = ConnectorAction.ActionType.EXECUTE.name(); if (inputDataset.getParameters() != null && inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD) != null) { connectorAction = inputDataset.getParameters().get(CONNECTOR_ACTION_FIELD); } + String connectorId = executeConnectorRequest.getConnectorId(); + if (MLIndicesHandler 
.doesMultiTenantIndexExist(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) { String finalConnectorAction = connectorAction; ActionListener listener = ActionListener.wrap(connector -> { if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) { - // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. - connector.decrypt(finalConnectorAction, (credential, tenantId) -> encryptor.decrypt(credential, null), null); - RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader - .initInstance(connector.getProtocol(), connector, Connector.class); - connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); - connectorExecutor.setScriptService(scriptService); - connectorExecutor.setClusterService(clusterService); - connectorExecutor.setClient(client); - connectorExecutor.setXContentRegistry(xContentRegistry); - connectorExecutor - .executeAction(finalConnectorAction, executeConnectorRequest.getMlInput(), ActionListener.wrap(taskResponse -> { - actionListener.onResponse(taskResponse); - }, e -> { actionListener.onFailure(e); })); + executeWithConnector(connector, finalConnectorAction, executeConnectorRequest, actionListener, true); } }, e -> { log.error("Failed to get connector " + connectorId, e); @@ -112,4 +100,31 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, + boolean decryptWithEncryptor + ) { + if (decryptWithEncryptor) { + connector.decrypt(action, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + } else { + connector.decrypt(action, (credential, tenantId) -> credential, null); + } + RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); + connectorExecutor.setScriptService(scriptService); + connectorExecutor.setClusterService(clusterService); + connectorExecutor.setClient(client); + connectorExecutor.setXContentRegistry(xContentRegistry); + connectorExecutor.executeAction(action, request.getMlInput(), ActionListener.wrap(response -> { + connector.removeCredential(); + listener.onResponse(response); + }, e -> { + connector.removeCredential(); + listener.onFailure(e); + })); + } + } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java index e05126322e..cde70c3bb3 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/TransportUpdateMemoryContainerAction.java @@ -286,21 +286,22 @@ private void validateAndCreateIndices( ) { String tenantId = container.getTenantId(); // Validate LLM model using helper - MemoryContainerModelValidator.validateLlmModel(tenantId, config.getLlmId(), mlModelManager, client, ActionListener.wrap(llmValid -> { - // LLM validated, now validate embedding model - MemoryContainerModelValidator - .validateEmbeddingModel( - tenantId, - config.getEmbeddingModelId(), - config.getEmbeddingModelType(), - mlModelManager, - client, - ActionListener.wrap(embeddingValid -> { - // Both models validated, proceed to shared index validation and creation - validateSharedIndexAndCreateIndices(container, config, 
updateFields, memoryContainerId, listener); - }, listener::onFailure) - ); - }, listener::onFailure)); + MemoryContainerModelValidator + .validateLlmModel(tenantId, config.getLlmId(), mlModelManager, client, ActionListener.wrap(llmValid -> { + // LLM validated, now validate embedding model + MemoryContainerModelValidator + .validateEmbeddingModel( + tenantId, + config.getEmbeddingModelId(), + config.getEmbeddingModelType(), + mlModelManager, + client, + ActionListener.wrap(embeddingValid -> { + // Both models validated, proceed to shared index validation and creation + validateSharedIndexAndCreateIndices(container, config, updateFields, memoryContainerId, listener); + }, listener::onFailure) + ); + }, listener::onFailure)); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java index ecc557a50e..faffdc5dad 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/MemoryContainerModelValidator.java @@ -38,7 +38,13 @@ public final class MemoryContainerModelValidator { * @param client The OpenSearch client * @param listener Action listener that receives true on success, or error on failure */ - public static void validateLlmModel(String tenantId, String llmId, MLModelManager modelManager, Client client, ActionListener listener) { + public static void validateLlmModel( + String tenantId, + String llmId, + MLModelManager modelManager, + Client client, + ActionListener listener + ) { if (llmId == null) { listener.onResponse(true); return; diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java index 9e5c973135..df79e95cc8 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java @@ -422,7 +422,7 @@ private void runConnector(Connector connector, String actionName, MLInput mlInpu if (connector == null) { throw new IllegalArgumentException("connector is null"); } - //TODO: current we only support internal connector inside memory container in OASIS. The tenant id is same with container's. + // TODO: current we only support internal connector inside memory container in OASIS. The tenant id is same with container's. // We should check tenant id in future if we use a standalone connector inside memory container. String connectorTenantId = connector.getTenantId(); // adding tenantID as null, because we are not implement multi-tenancy for this feature yet. 
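
Returning to the RemoteAgenticConversationMemory.Factory defined earlier in this file: its create(Map, ActionListener) entry point expects the session identity and the inline-connector settings to arrive together in one parameter map, and when no existing memory id is supplied under the MEMORY_ID key it first calls create_session against the container and reuses the returned session id. A minimal sketch of such a map follows (not part of the patch; all values are placeholders, the factory variable is illustrative, and only "memory_container_id" and "endpoint" are mandatory per the validation above):

    // Sketch only: keys mirror the literals read by Factory.create(Map, ActionListener);
    // values are illustrative placeholders.
    Map<String, Object> memoryParams = new HashMap<>();
    memoryParams.put("memory_container_id", "container-123");                        // required
    memoryParams.put("endpoint", "https://my-domain.us-west-2.es.amazonaws.com");    // required
    memoryParams.put("region", "us-west-2");                                         // optional, enables aws_sigv4
    memoryParams.put("credential", Map.of("access_key", "...", "secret_key", "..."));// optional SigV4 credentials
    memoryParams.put("user_id", "alice");                                            // optional
    // factory.create(memoryParams, ActionListener.wrap(memory -> { /* use memory */ }, e -> { /* handle failure */ }));
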
diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a9e5000def..54fe5f0d26 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -281,6 +281,7 @@ import org.opensearch.ml.engine.memory.AgenticConversationMemory; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.ml.engine.memory.MLMemoryManager; +import org.opensearch.ml.engine.memory.RemoteAgenticConversationMemory; import org.opensearch.ml.engine.tools.AgentTool; import org.opensearch.ml.engine.tools.ConnectorTool; import org.opensearch.ml.engine.tools.IndexInsightTool; @@ -883,6 +884,10 @@ public Collection createComponents( agenticConversationMemoryFactory.init(client); memoryFactoryMap.put(AgenticConversationMemory.TYPE, agenticConversationMemoryFactory); + RemoteAgenticConversationMemory.Factory remoteAgenticConversationMemoryFactory = new RemoteAgenticConversationMemory.Factory(); + remoteAgenticConversationMemoryFactory.init(scriptService, clusterService, client, xContentRegistry); + memoryFactoryMap.put(RemoteAgenticConversationMemory.TYPE, remoteAgenticConversationMemoryFactory); + MLAgentExecutor agentExecutor = new MLAgentExecutor( client, sdkClient, From d06bc8dab5c71bde3253857822d71f427a098cc4 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 7 Nov 2025 16:25:42 -0800 Subject: [PATCH 29/58] not stashing context if index is not system index (#4407) Signed-off-by: Dhrubo Saha --- .../ml/engine/indices/MLIndicesHandler.java | 60 +++++++++++++++++-- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 4eb567802c..98c91f75ba 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -31,6 +31,7 @@ import java.io.UncheckedIOException; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import org.opensearch.OpenSearchWrapperException; @@ -175,19 +176,31 @@ public String getMapping(String mappingPath) { public void createSessionMemoryDataIndex(String indexName, MemoryConfiguration configuration, ActionListener listener) { String indexMappings = getMapping(ML_MEMORY_SESSION_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(SESSION_INDEX); - initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + if (configuration.isUseSystemIndex()) { + initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + } else { + initIndexWithContext(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + } } public void createWorkingMemoryDataIndex(String indexName, MemoryConfiguration configuration, ActionListener listener) { String indexMappings = getMapping(ML_WORKING_MEMORY_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(WORKING_MEMORY_INDEX); - initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + if (configuration.isUseSystemIndex()) { + initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, 
listener); + } else { + initIndexWithContext(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + } } public void createLongTermMemoryHistoryIndex(String indexName, MemoryConfiguration configuration, ActionListener listener) { String indexMappings = getMapping(ML_LONG_MEMORY_HISTORY_INDEX_MAPPING_PATH); Map indexSettings = configuration.getMemoryIndexMapping(LONG_TERM_MEMORY_HISTORY_INDEX); - initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + if (configuration.isUseSystemIndex()) { + initIndexIfAbsent(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + } else { + initIndexWithContext(indexName, StringUtils.toJson(indexMappings), indexSettings, 1, listener); + } } /** @@ -271,7 +284,11 @@ public void createLongTermMemoryIndex( } // Initialize index with mapping and settings - initIndexIfAbsent(indexName, indexMappings, indexSettings, 1, listener); + if (memoryConfig.isUseSystemIndex()) { + initIndexIfAbsent(indexName, indexMappings, indexSettings, 1, listener); + } else { + initIndexWithContext(indexName, indexMappings, indexSettings, 1, listener); + } } catch (Exception e) { log.error("Failed to create long-term memory index", e); listener.onFailure(e); @@ -287,6 +304,41 @@ public void initIndexIfAbsent(String indexName, String mapping, Integer version, initIndexIfAbsent(indexName, mapping, null, version, listener); } + public void initIndexWithContext( + String indexName, + String mapping, + Map indexSettings, + Integer version, + ActionListener listener + ) { + log.info("Using initIndexWithContext method to create index: {}", indexName); + try { + ActionListener actionListener = ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + log.info("create index:{}", indexName); + listener.onResponse(true); + } else { + listener.onResponse(false); + } + }, e -> { + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + log.info("Skip creating the Index:{} that is already created by another parallel request", indexName); + listener.onResponse(true); + } else { + log.error("Failed to create index {}", indexName, e); + listener.onFailure(e); + } + }); + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping, XContentType.JSON); + request.settings(Objects.requireNonNullElse(indexSettings, DEFAULT_INDEX_SETTINGS)); + client.admin().indices().create(request, actionListener); + } catch (Exception e) { + log.error("Failed to init index {}", indexName, e); + listener.onFailure(e); + } + } + public void initIndexIfAbsent( String indexName, String mapping, From bc60a50d666a9035d29b564e11f95d2814f051a4 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Sun, 9 Nov 2025 18:19:36 -0800 Subject: [PATCH 30/58] Update the POST_TOOL hook emit saving to agentic memory (#4408) * Fix POST_TOOL hook interaction updates and add tenant ID support Signed-off-by: Mingshi Liu - Fix POST_TOOL hook to return full ContextManagerContext like PRE_LLM hook - Update MLChatAgentRunner to properly handle interaction updates from POST_TOOL hook - Ensure interactions list and tmpParameters.INTERACTIONS stay synchronized - Add tenant ID support to MLPredictionTaskRequest in ModelGuardrail and SummarizationManager Signed-off-by: Mingshi Liu * fix error message escaping Signed-off-by: Mingshi Liu * consolicate post_hook logic Signed-off-by: Mingshi Liu --------- Signed-off-by: Mingshi Liu --- 
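
In practical terms, emitPostToolHook now hands back the whole ContextManagerContext instead of the (possibly processed) tool output, and the caller pulls the processed text out of the returned context. A rough sketch of the consuming pattern introduced below, with illustrative variable names, is:

    // Sketch of the new POST_TOOL consumption pattern (names illustrative, not part of the diff)
    ContextManagerContext ctx = AgentContextUtil.emitPostToolHook(rawToolOutput, tmpParameters, toolSpecs, null, hookRegistry);
    String processed = ctx.getParameters().get("_current_tool_output"); // set by context managers; may be null
    Object effectiveOutput = processed != null ? processed : rawToolOutput; // fall back to the raw tool output
    nextStepListener.onResponse(effectiveOutput);
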
.../ml/common/model/ModelGuardrail.java | 6 ++- .../ml/engine/agents/AgentContextUtil.java | 23 ++++------ .../algorithms/agent/MLChatAgentRunner.java | 46 +++++++++---------- .../contextmanager/SummarizationManager.java | 9 +++- 4 files changed, 45 insertions(+), 39 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java index 9b1b6c6a81..b32e6471d7 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -7,6 +7,7 @@ import static java.util.concurrent.TimeUnit.SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; @@ -125,13 +126,16 @@ public Boolean validate(String in, Map parameters) { guardrailModelParams.put("response_filter", responseFilter); } log.info("Guardrail resFilter: {}", responseFilter); + String tenantId = parameters != null ? parameters.get(TENANT_ID_FIELD) : null; ActionRequest request = new MLPredictionTaskRequest( modelId, RemoteInferenceMLInput .builder() .algorithm(FunctionName.REMOTE) .inputDataset(RemoteInferenceInputDataSet.builder().parameters(guardrailModelParams).build()) - .build() + .build(), + null, + tenantId ); client.execute(MLPredictionTaskAction.INSTANCE, request, new LatchedActionListener(actionListener, latch)); try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java index 83180d9551..0715262e8c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java @@ -104,36 +104,33 @@ public static ContextManagerContext buildContextManagerContext( return builder.build(); } - public static Object emitPostToolHook( + public static ContextManagerContext emitPostToolHook( Object toolOutput, Map parameters, List toolSpecs, Memory memory, HookRegistry hookRegistry ) { + ContextManagerContext context = buildContextManagerContextForToolOutput( + StringUtils.toJson(toolOutput), + parameters, + toolSpecs, + memory + ); + if (hookRegistry != null) { try { if (toolOutput == null) { log.warn("Tool output is null, skipping POST_TOOL hook"); - return null; + return context; } - ContextManagerContext context = buildContextManagerContextForToolOutput( - StringUtils.toJson(toolOutput), - parameters, - toolSpecs, - memory - ); EnhancedPostToolEvent event = new EnhancedPostToolEvent(null, null, context, new HashMap<>()); hookRegistry.emit(event); - - Object processedOutput = extractProcessedToolOutput(context); - return processedOutput != null ? 
processedOutput : toolOutput; } catch (Exception e) { log.error("Failed to emit POST_TOOL hook event", e); - return toolOutput; } } - return toolOutput; + return context; } public static ContextManagerContext emitPreLLMHook( diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 21bdf26fd5..04cb78b52c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -486,7 +486,7 @@ private void runReAct( ((ActionListener) nextStepListener).onResponse(res); } } else { - // filteredOutput is the POST Tool output + // output is now the processed output from POST_TOOL hook in runTool Object filteredOutput = filterToolOutput(lastToolParams, output); addToolOutputToAddtionalInfo(toolSpecMap, lastAction, additionalInfo, filteredOutput); @@ -499,6 +499,7 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); + // Save trace with processed output saveTraceData( memory, "ReAct", @@ -680,26 +681,23 @@ private static void runTool( try { String finalAction = action; ActionListener toolListener = ActionListener.wrap(r -> { - if (functionCalling != null) { - String outputResponse = parseResponse(filterToolOutput(toolParams, r)); + // Emit POST_TOOL hook event - common for all tool executions + List postToolSpecs = new ArrayList<>(toolSpecMap.values()); + ContextManagerContext contextAfterPostTool = AgentContextUtil + .emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); - // Emit POST_TOOL hook event after tool execution and process current tool - // output - List postToolSpecs = new ArrayList<>(toolSpecMap.values()); - String outputResponseAfterHook = AgentContextUtil - .emitPostToolHook(outputResponse, tmpParameters, postToolSpecs, null, hookRegistry) - .toString(); + // Extract processed output from POST_TOOL hook + String processedToolOutput = contextAfterPostTool.getParameters().get("_current_tool_output"); + Object processedOutput = processedToolOutput != null ? 
processedToolOutput : r; + if (functionCalling != null) { + String outputResponse = parseResponse(filterToolOutput(toolParams, processedOutput)); List> toolResults = List - .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponseAfterHook))); + .of(Map.of(TOOL_CALL_ID, toolCallId, TOOL_RESULT, Map.of("text", outputResponse))); List llmMessages = functionCalling.supply(toolResults); - // TODO: support multiple tool calls at the same time so that multiple - // LLMMessages can be generated here + // TODO: support multiple tool calls at the same time so that multiple LLMMessages can be generated here interactions.add(llmMessages.getFirst().getResponse()); } else { - // Emit POST_TOOL hook event for non-function calling path - List postToolSpecs = new ArrayList<>(toolSpecMap.values()); - Object processedOutput = AgentContextUtil.emitPostToolHook(r, tmpParameters, postToolSpecs, null, hookRegistry); interactions .add( substitute( @@ -709,25 +707,25 @@ private static void runTool( ) ); } - nextStepListener.onResponse(r); + nextStepListener.onResponse(processedOutput); }, e -> { interactions .add( substitute( tmpParameters.get(INTERACTION_TEMPLATE_TOOL_RESPONSE), - Map.of(TOOL_CALL_ID, toolCallId, "tool_response", "Tool " + action + " failed: " + e.getMessage()), + Map + .of( + TOOL_CALL_ID, + toolCallId, + "tool_response", + "Tool " + action + " failed: " + StringUtils.processTextDoc(e.getMessage()) + ), INTERACTIONS_PREFIX ) ); nextStepListener .onResponse( - String - .format( - Locale.ROOT, - "Failed to run the tool %s with the error message %s.", - finalAction, - e.getMessage().replaceAll("\\n", "\n") - ) + String.format(Locale.ROOT, "Failed to run the tool %s with the error message %s.", finalAction, e.getMessage()) ); }); if (tools.get(action) instanceof MLModelTool) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java index 75128e266a..38ea1b4aac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SummarizationManager.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine.algorithms.contextmanager; import static java.lang.Math.min; +import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; @@ -193,7 +194,13 @@ protected void executeSummarization( MLInput mlInput = MLInput.builder().algorithm(REMOTE).inputDataset(inputDataset).build(); // Create prediction request - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId(modelId).mlInput(mlInput).build(); + String tenantId = (String) context.getParameter(TENANT_ID_FIELD); + MLPredictionTaskRequest request = MLPredictionTaskRequest + .builder() + .modelId(modelId) + .mlInput(mlInput) + .tenantId(tenantId) + .build(); // Execute prediction ActionListener listener = ActionListener.wrap(response -> { From addda969369b99d43bf7902f2c596a9d0de47589 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Mon, 10 Nov 2025 02:05:54 -0800 Subject: [PATCH 31/58] wrap remote agent memory config into an object to avoid override (#4411) --- 
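
With this change the remote memory settings no longer ride along as individual top-level request parameters, where they could override other agent parameters, but are wrapped in a single JSON object that createMemoryParams parses. An illustrative value for that parameter, with placeholder ids, endpoint, region and credentials (the key is renamed to "memory_configuration" in the follow-up fix further below), would be:

    // Illustrative only; every value here is a placeholder.
    String remoteMemoryConfig = """
        {
          "memory_container_id": "container-123",
          "endpoint": "https://my-domain.us-west-2.es.amazonaws.com",
          "region": "us-west-2",
          "credential": { "access_key": "...", "secret_key": "...", "session_token": "..." },
          "user_id": "alice"
        }
        """;
    requestParameters.put("remote_agent_memory_configuration", remoteMemoryConfig);
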
.../engine/algorithms/agent/AgentUtils.java | 62 ++++++++++++++----- .../RemoteAgenticConversationMemory.java | 10 +-- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 4b3cf23511..0a1fe8fdd2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -1044,22 +1044,52 @@ public static Map createMemoryParams( memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); } if (requestParameters != null) { - String endpointParam = requestParameters.get(ENDPOINT_FIELD); - if (!Strings.isNullOrEmpty(endpointParam)) { - memoryParams.put(ENDPOINT_FIELD, endpointParam); - } - String regionParam = requestParameters.get(HttpConnector.REGION_FIELD); - if (!Strings.isNullOrEmpty(regionParam)) { - memoryParams.put(HttpConnector.REGION_FIELD, regionParam); - } - Map credential = parseStringMapParameter(requestParameters.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); - if (credential != null && !credential.isEmpty()) { - memoryParams.put(CREDENTIAL_FIELD, credential); - } - // Extract user_id if provided - String userIdParam = requestParameters.get("user_id"); - if (!Strings.isNullOrEmpty(userIdParam)) { - memoryParams.put("user_id", userIdParam); + // Check if parameters are wrapped in remote_agent_memory_configuration + String remoteMemoryConfigStr = requestParameters.get("remote_agent_memory_configuration"); + if (!Strings.isNullOrEmpty(remoteMemoryConfigStr)) { + // Parse the remote_agent_memory_configuration JSON + try ( + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, remoteMemoryConfigStr) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Map remoteMemoryConfig = parser.map(); + + // Extract memory_container_id + String memoryContainerIdParam = (String) remoteMemoryConfig.get(MEMORY_CONTAINER_ID_FIELD); + if (!Strings.isNullOrEmpty(memoryContainerIdParam)) { + memoryParams.put(MEMORY_CONTAINER_ID_FIELD, memoryContainerIdParam); + } + + // Extract endpoint + String endpointParam = (String) remoteMemoryConfig.get(ENDPOINT_FIELD); + if (!Strings.isNullOrEmpty(endpointParam)) { + memoryParams.put(ENDPOINT_FIELD, endpointParam); + } + + // Extract region + String regionParam = (String) remoteMemoryConfig.get(HttpConnector.REGION_FIELD); + if (!Strings.isNullOrEmpty(regionParam)) { + memoryParams.put(HttpConnector.REGION_FIELD, regionParam); + } + + // Extract credential + Object credentialObj = remoteMemoryConfig.get(CREDENTIAL_FIELD); + if (credentialObj instanceof Map) { + Map credential = (Map) credentialObj; + if (!credential.isEmpty()) { + memoryParams.put(CREDENTIAL_FIELD, credential); + } + } + + // Extract user_id if provided + String userIdParam = (String) remoteMemoryConfig.get("user_id"); + if (!Strings.isNullOrEmpty(userIdParam)) { + memoryParams.put("user_id", userIdParam); + } + } catch (Exception e) { + log.error("Failed to parse remote_agent_memory_configuration", e); + } } memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java index 
c0c74b0d23..4a4cb45508 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java @@ -1201,12 +1201,14 @@ private String extractServiceName(String endpoint) { if (endpoint.contains(".aoss.amazonaws.com")) { return "aoss"; } - // For managed OpenSearch: https://xxx.us-west-2.es.amazonaws.com - if (endpoint.contains(".es.amazonaws.com")) { + // For managed OpenSearch (production, staging, integration) + if (endpoint.contains(".es.amazonaws.com") + || endpoint.contains(".es-staging.amazonaws.com") + || endpoint.contains(".es-integ.amazonaws.com")) { return "es"; } - // Default to es for OpenSearch/Elasticsearch service - return "es"; + // Default to aoss for other OpenSearch services + return "aoss"; } /** From 20eb9bcb4a37e6d910207e1013c927f22e5367dc Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Tue, 11 Nov 2025 17:20:05 -0800 Subject: [PATCH 32/58] Fix 403 during execute connector (#4414) --- .../engine/algorithms/agent/AgentUtils.java | 33 +++++++----- .../RemoteAgenticConversationMemory.java | 51 ++++++++++++++++++- .../ExecuteConnectorTransportAction.java | 5 +- 3 files changed, 73 insertions(+), 16 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 0a1fe8fdd2..38a6a4ed7c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -1044,37 +1044,37 @@ public static Map createMemoryParams( memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); } if (requestParameters != null) { - // Check if parameters are wrapped in remote_agent_memory_configuration - String remoteMemoryConfigStr = requestParameters.get("remote_agent_memory_configuration"); - if (!Strings.isNullOrEmpty(remoteMemoryConfigStr)) { - // Parse the remote_agent_memory_configuration JSON + // Check if parameters are wrapped in memory_configuration + String memoryConfigStr = requestParameters.get("memory_configuration"); + if (!Strings.isNullOrEmpty(memoryConfigStr)) { + // Parse the memory_configuration JSON try ( XContentParser parser = JsonXContent.jsonXContent - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, remoteMemoryConfigStr) + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, memoryConfigStr) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Map remoteMemoryConfig = parser.map(); + Map memoryConfig = parser.map(); // Extract memory_container_id - String memoryContainerIdParam = (String) remoteMemoryConfig.get(MEMORY_CONTAINER_ID_FIELD); + String memoryContainerIdParam = (String) memoryConfig.get(MEMORY_CONTAINER_ID_FIELD); if (!Strings.isNullOrEmpty(memoryContainerIdParam)) { memoryParams.put(MEMORY_CONTAINER_ID_FIELD, memoryContainerIdParam); } // Extract endpoint - String endpointParam = (String) remoteMemoryConfig.get(ENDPOINT_FIELD); + String endpointParam = (String) memoryConfig.get(ENDPOINT_FIELD); if (!Strings.isNullOrEmpty(endpointParam)) { memoryParams.put(ENDPOINT_FIELD, endpointParam); } // Extract region - String regionParam = (String) remoteMemoryConfig.get(HttpConnector.REGION_FIELD); + String regionParam = (String) memoryConfig.get(HttpConnector.REGION_FIELD); if 
(!Strings.isNullOrEmpty(regionParam)) { memoryParams.put(HttpConnector.REGION_FIELD, regionParam); } // Extract credential - Object credentialObj = remoteMemoryConfig.get(CREDENTIAL_FIELD); + Object credentialObj = memoryConfig.get(CREDENTIAL_FIELD); if (credentialObj instanceof Map) { Map credential = (Map) credentialObj; if (!credential.isEmpty()) { @@ -1082,13 +1082,22 @@ public static Map createMemoryParams( } } + // Check for direct roleArn field - if present, override credential map + String roleArnParam = (String) memoryConfig.get("roleArn"); + if (!Strings.isNullOrEmpty(roleArnParam)) { + // Override credential with roleArn + Map roleArnCredential = new HashMap<>(); + roleArnCredential.put("roleArn", roleArnParam); + memoryParams.put(CREDENTIAL_FIELD, roleArnCredential); + } + // Extract user_id if provided - String userIdParam = (String) remoteMemoryConfig.get("user_id"); + String userIdParam = (String) memoryConfig.get("user_id"); if (!Strings.isNullOrEmpty(userIdParam)) { memoryParams.put("user_id", userIdParam); } } catch (Exception e) { - log.error("Failed to parse remote_agent_memory_configuration", e); + log.error("Failed to parse memory_configuration", e); } } memoryParams.put(TENANT_ID_FIELD, mlAgent.getTenantId()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java index 4a4cb45508..5a56ef68dc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/memory/RemoteAgenticConversationMemory.java @@ -1032,6 +1032,9 @@ private Connector createInlineConnector(String endpoint, String region, Map actions = createMemoryContainerActions(); @@ -1046,6 +1049,7 @@ private Connector createInlineConnector(String endpoint, String region, Map cred, // No-op function - credentials are already plaintext - null // No tenant ID for inline connectors + (cred, tenant) -> cred, // No-op function - credentials are already plaintext + tenantId ); return connector; @@ -1211,6 +1216,48 @@ private String extractServiceName(String endpoint) { return "aoss"; } + /** + * Extract dummy tenant ID from role ARN for AOSS services. 
+ * Essentially we only need the front part (account ID) as client ID when in AOSS + * + * @param serviceName The AWS service name (aoss or es) + * @param credential The credential map that may contain roleArn + * @return Tenant ID in format "account:role" for AOSS, null for ES or if roleArn not present + * + * Example: + * - Input: serviceName="aoss", roleArn="arn:aws:iam::123456789012:role/role-name" + * - Output: "123456789012:role" + */ + private String extractTenantIdFromRoleArn(String serviceName, Map credential) { + // Return null for ES service + if (!"aoss".equals(serviceName)) { + return null; + } + + // Check if credential map exists and contains roleArn + if (credential == null || !credential.containsKey("roleArn")) { + return null; + } + + String roleArn = credential.get("roleArn"); + if (Strings.isNullOrEmpty(roleArn)) { + return null; + } + + // Expected format: arn:aws:iam::{account}:role/{role-name} + try { + String[] parts = roleArn.split(":"); + if (parts.length >= 5 && "role".equals(parts[4])) { + String account = parts[3]; + return account + ":role"; + } + } catch (Exception e) { + log.error("Failed to parse roleArn: {}", roleArn, e); + } + + return null; + } + /** * Validate endpoint URL */ diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java index 9f514b30e6..808ad2156c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java @@ -107,10 +107,11 @@ private void executeWithConnector( ActionListener listener, boolean decryptWithEncryptor ) { + String connectorTenantId = connector.getTenantId(); if (decryptWithEncryptor) { - connector.decrypt(action, (credential, tenantId) -> encryptor.decrypt(credential, null), null); + connector.decrypt(action, (credential, tenantId) -> encryptor.decrypt(credential, tenantId), connectorTenantId); } else { - connector.decrypt(action, (credential, tenantId) -> credential, null); + connector.decrypt(action, (credential, tenantId) -> credential, connectorTenantId); } RemoteConnectorExecutor connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); connectorExecutor.setConnectorPrivateIpEnabled(mlFeatureEnabledSetting.isConnectorPrivateIpEnabled()); From df74b5151277dcb68e7a40a64893355affc862e5 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 14:51:51 +0800 Subject: [PATCH 33/58] init max step summary Signed-off-by: Jiaru Jiang --- build.gradle | 2 + .../algorithms/agent/MLChatAgentRunner.java | 152 +++++++++++++++++- .../algorithms/agent/PromptTemplate.java | 3 + .../agent/MLChatAgentRunnerTest.java | 91 ++++++++++- 4 files changed, 244 insertions(+), 4 deletions(-) diff --git a/build.gradle b/build.gradle index cfc4419471..7334235127 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) + force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -95,6 +96,7 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. 
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" + resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}" resolutionStrategy.force "io.netty:netty-codec:${versions.netty}" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 04cb78b52c..d4d0fae239 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -36,6 +36,7 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.security.PrivilegedActionException; import java.util.ArrayList; @@ -125,9 +126,12 @@ public class MLChatAgentRunner implements MLAgentRunner { public static final String DATETIME_FORMAT_FIELD = "datetime_format"; public static final String SYSTEM_PROMPT_FIELD = "system_prompt"; private static final String DEFAULT_SYSTEM_PROMPT = "You are an helpful assistant."; // empty system prompt + public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration"; private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = + "Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -352,6 +356,8 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); final List interactions = new CopyOnWriteArrayList<>(); + List interactions = new CopyOnWriteArrayList<>(); + List executionSteps = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); @@ -411,6 +417,17 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); + // Record execution step for summary + if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { + executionSteps.add(String.format("Thought: %s", thought.trim())); + } + if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { + String actionDesc = actionInput != null && !"null".equals(actionInput) + ? 
String.format("Action: %s(%s)", action.trim(), actionInput.trim()) + : String.format("Action: %s", action.trim()); + executionSteps.add(actionDesc); + } + traceTensors .add( ModelTensors @@ -445,7 +462,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -500,6 +521,10 @@ private void runReAct( scratchpadBuilder.append(toolResponse).append("\n\n"); // Save trace with processed output + // Record tool result for summary + String outputSummary = outputToOutputString(filteredOutput); + executionSteps.add(String.format("Result: %s", outputSummary)); + saveTraceData( memory, "ReAct", @@ -550,7 +575,11 @@ private void runReAct( additionalInfo, lastThought, maxIterations, - tools + tools, + tmpParameters, + executionSteps, + llm, + tenantId ); return; } @@ -760,7 +789,7 @@ private static void runTool( * some special parameters like SCRATCHPAD_NOTES_KEY, * some new llmToolTmpParameters produced by the tool run can opt to be copied * back to tmpParameters to share across tools in the same interaction - * + * * @param tmpParameters * @param llmToolTmpParameters */ @@ -969,6 +998,65 @@ public static void returnFinalResponse( } private void handleMaxIterationsReached( + String sessionId, + ActionListener listener, + String question, + String parentInteractionId, + boolean verbose, + boolean traceDisabled, + List traceTensors, + ConversationIndexMemory conversationIndexMemory, + AtomicInteger traceNumber, + Map additionalInfo, + AtomicReference lastThought, + int maxIterations, + Map tools, + Map parameters, + List executionSteps, + LLMSpec llmSpec, + String tenantId + ) { + boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); + + if (shouldSummarize && !executionSteps.isEmpty()) { + generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + sendFinalAnswer( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + incompleteResponse + ); + cleanUpResource(tools); + }, e -> { log.warn("Failed to generate LLM summary", e); })); + } else { + // Use traditional approach + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + lastThought, + maxIterations, + tools + ); + } + } + + private void sendTraditionalMaxIterationsResponse( String sessionId, ActionListener listener, String question, @@ -1002,6 +1090,64 @@ private void handleMaxIterationsReached( cleanUpResource(tools); } + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + if (stepsSummary == null || stepsSummary.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + summaryParams.put("inputs", summaryPrompt); + summaryParams.put("prompt", summaryPrompt); + summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); + + 
ActionRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary != null) { + listener.onResponse(summary); + } else { + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + } + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + String outputString = outputToOutputString(response.getOutput()); + if (outputString != null && !outputString.trim().isEmpty()) { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } + } + } + return null; + } catch (Exception e) { + log.warn("Failed to extract summary from response", e); + return null; + } + } + private void saveMessage( Memory memory, String question, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index 67b29c3557..f1b1147a7d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -140,4 +140,7 @@ public class PromptTemplate { - Avoid making assumptions and relying on implicit knowledge. - Your response must be self-contained and ready for the planner to use without modification. Never end with a question. - Break complex searches into simpler queries when appropriate."""; + + public static final String SUMMARY_PROMPT_TEMPLATE = + "Please provide a concise summary of the following agent execution steps. 
Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 1855c6fbd0..12625f8205 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -681,7 +681,7 @@ public void testToolThrowException() { .when(firstTool) .run(Mockito.anyMap(), toolListenerCaptor.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); + mlChatAgentRunner.run(mlAgent, params, agentActionListener); // Verify that the tool's run method was called verify(firstTool).run(any(), any()); @@ -1172,6 +1172,95 @@ public void testConstructLLMParams_DefaultValues() { Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE)); } + @Test + public void testMaxIterationsWithSummaryEnabled() { + // Create LLM spec with max_iteration = 1 to simplify test + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset and setup fresh mocks + Mockito.reset(client); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM response + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) + .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains summary message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertTrue( + response.startsWith("Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps taken:") + ); + assertTrue(response.contains("Summary: Analysis step was attempted")); + } + + @Test + public void testMaxIterationsWithSummaryDisabled() { + // Create LLM spec with max_iteration = 1 and summary disabled + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + // Reset client mock for this test + Mockito.reset(client); + // Mock LLM response that doesn't contain final_answer + Mockito + .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) + .when(client) + .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); + + mlChatAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify response is captured + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + // Verify the response contains traditional max iterations message + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response); + } + @Test public void testCreateMemoryAdapter_ConversationIndex() { // Test that ConversationIndex memory type returns ConversationIndexMemory From 0381e052dcb6c9f86c1108cb8de246914412e6fb Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 22 Sep 2025 15:01:30 +0800 Subject: [PATCH 34/58] fix:recover build.gradle Signed-off-by: Jiaru Jiang --- build.gradle | 2 -- 1 file changed, 2 deletions(-) diff --git a/build.gradle b/build.gradle index 7334235127..cfc4419471 100644 --- a/build.gradle +++ b/build.gradle @@ -44,7 +44,6 @@ buildscript { configurations.all { resolutionStrategy { force("org.eclipse.platform:org.eclipse.core.runtime:3.29.0") // for spotless transitive dependency CVE (for 3.26.100) - force("com.google.errorprone:error_prone_annotations:2.18.0") } } } @@ -96,7 +95,6 @@ subprojects { configurations.all { // Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades. 
resolutionStrategy.force "com.google.guava:guava:32.1.3-jre" - resolutionStrategy.force "com.google.errorprone:error_prone_annotations:2.18.0" resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' resolutionStrategy.force "io.netty:netty-buffer:${versions.netty}" resolutionStrategy.force "io.netty:netty-codec:${versions.netty}" From 49161196adf9f7c70f625637884cf6b6f56c76cf Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:12:14 +0800 Subject: [PATCH 35/58] add:increase test coverage Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 2 +- .../agent/MLChatAgentRunnerTest.java | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index d4d0fae239..729cfec00e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1129,7 +1129,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan } } - private String extractSummaryFromResponse(MLTaskResponse response) { + public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 12625f8205..0d6f8be474 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1392,4 +1392,32 @@ public void testSimpleChatHistoryTemplateEngine() { assertTrue("Should contain assistant message", chatHistory.contains("Assistant: It's sunny today!")); assertTrue("Should contain system context", chatHistory.contains("[Context] Weather data retrieved from API")); } + + @Test + public void testExtractSummaryFromResponse() { + MLTaskResponse response = MLTaskResponse.builder() + .output(ModelTensorOutput.builder() + .mlModelOutputs(Arrays.asList( + ModelTensors.builder() + .mlModelTensors(Arrays.asList( + ModelTensor.builder() + .dataAsMap(ImmutableMap.of("response", "Valid summary text")) + .build())) + .build())) + .build()) + .build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Valid summary text", result); + } + + @Test + public void testGenerateLLMSummaryWithNullSteps() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); + ActionListener listener = Mockito.mock(ActionListener.class); + + mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener); + + verify(listener).onFailure(any(IllegalArgumentException.class)); + } } From d762e159dd6420097d3a22513f382a471de5db6b Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 10:14:24 +0800 Subject: [PATCH 36/58] fix:spotlessApply Signed-off-by: Jiaru Jiang --- .../agent/MLChatAgentRunnerTest.java | 31 +++++++++++++------ 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java 
b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0d6f8be474..e51893b2ee 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1395,16 +1395,27 @@ public void testSimpleChatHistoryTemplateEngine() { @Test public void testExtractSummaryFromResponse() { - MLTaskResponse response = MLTaskResponse.builder() - .output(ModelTensorOutput.builder() - .mlModelOutputs(Arrays.asList( - ModelTensors.builder() - .mlModelTensors(Arrays.asList( - ModelTensor.builder() - .dataAsMap(ImmutableMap.of("response", "Valid summary text")) - .build())) - .build())) - .build()) + MLTaskResponse response = MLTaskResponse + .builder() + .output( + ModelTensorOutput + .builder() + .mlModelOutputs( + Arrays + .asList( + ModelTensors + .builder() + .mlModelTensors( + Arrays + .asList( + ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() + ) + ) + .build() + ) + ) + .build() + ) .build(); String result = mlChatAgentRunner.extractSummaryFromResponse(response); From 8a52d92036858783f4a106900b7bb3a8d137fe49 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:01:35 +0800 Subject: [PATCH 37/58] fix:use traceTensor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 38 ++++++------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 729cfec00e..28c6a06959 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -130,8 +130,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; - private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = - "Agent reached maximum iterations (%d) without completing the task. Here's a summary of the steps taken:\n\n%s"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". 
Here's a summary of the steps taken:\n\n%s"; private Client client; private Settings settings; @@ -357,8 +356,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); final List interactions = new CopyOnWriteArrayList<>(); List interactions = new CopyOnWriteArrayList<>(); - List executionSteps = new CopyOnWriteArrayList<>(); - StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -417,17 +414,6 @@ private void runReAct( lastActionInput.set(actionInput); lastToolSelectionResponse.set(thoughtResponse); - // Record execution step for summary - if (thought != null && !"null".equals(thought) && !thought.trim().isEmpty()) { - executionSteps.add(String.format("Thought: %s", thought.trim())); - } - if (action != null && !"null".equals(action) && !action.trim().isEmpty()) { - String actionDesc = actionInput != null && !"null".equals(actionInput) - ? String.format("Action: %s(%s)", action.trim(), actionInput.trim()) - : String.format("Action: %s", action.trim()); - executionSteps.add(actionDesc); - } - traceTensors .add( ModelTensors @@ -464,7 +450,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -520,11 +505,6 @@ private void runReAct( ); scratchpadBuilder.append(toolResponse).append("\n\n"); - // Save trace with processed output - // Record tool result for summary - String outputSummary = outputToOutputString(filteredOutput); - executionSteps.add(String.format("Result: %s", outputSummary)); - saveTraceData( memory, "ReAct", @@ -577,7 +557,6 @@ private void runReAct( maxIterations, tools, tmpParameters, - executionSteps, llm, tenantId ); @@ -1012,14 +991,13 @@ private void handleMaxIterationsReached( int maxIterations, Map tools, Map parameters, - List executionSteps, LLMSpec llmSpec, String tenantId ) { boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); - if (shouldSummarize && !executionSteps.isEmpty()) { - generateLLMSummary(executionSteps, llmSpec, tenantId, ActionListener.wrap(summary -> { + if (shouldSummarize && !traceTensors.isEmpty()) { + generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, @@ -1090,7 +1068,7 @@ private void sendTraditionalMaxIterationsResponse( cleanUpResource(tools); } - void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { + void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener listener) { if (stepsSummary == null || stepsSummary.isEmpty()) { listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); return; @@ -1101,7 +1079,13 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String tenan if (llmSpec.getParameters() != null) { summaryParams.putAll(llmSpec.getParameters()); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepsSummary)); + + // Convert ModelTensors to strings before joining + List stepStrings = new ArrayList<>(); + for (ModelTensors tensor : stepsSummary) { + stepStrings.add(outputToOutputString(tensor)); + } + String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", 
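// note: stepStrings holds the string form of each recorded trace step (built via outputToOutputString above); joining them with newlines produces the body of the summary prompt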
stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From f0ed0a510f4b50036fe115187fdb1d582fd64368 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:13:38 +0800 Subject: [PATCH 38/58] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 28c6a06959..5a9934d61d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -998,7 +998,7 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendFinalAnswer( sessionId, listener, From a06f3002bbea40101cd422165c0191f75ec98495 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 15:53:50 +0800 Subject: [PATCH 39/58] fix:String.format() Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 5a9934d61d..9e74a0444f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1085,7 +1085,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String for (ModelTensors tensor : stepsSummary) { stepStrings.add(outputToOutputString(tensor)); } - String summaryPrompt = String.format(SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put("prompt", summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); From d6a073380a32a1e4fe8616e290c4db84bf62fce2 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:17:45 +0800 Subject: [PATCH 40/58] fix:reuse sendTraditionalMaxIterationsResponse method Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 9e74a0444f..cf9018365e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -998,8 +998,9 @@ private void handleMaxIterationsReached( if (shouldSummarize && 
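// note: the summary path is taken only when the summarize-when-max-iteration flag is enabled and at least one trace step has been recorded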
!traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String incompleteResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - sendFinalAnswer( + String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); + AtomicReference summaryThought = new AtomicReference<>(summaryResponse); + sendTraditionalMaxIterationsResponse( sessionId, listener, question, @@ -1010,12 +1011,12 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - incompleteResponse + summaryThought, + 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + tools ); - cleanUpResource(tools); }, e -> { log.warn("Failed to generate LLM summary", e); })); } else { - // Use traditional approach sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -1049,9 +1050,16 @@ private void sendTraditionalMaxIterationsResponse( int maxIterations, Map tools ) { - String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + String incompleteResponse; + if (maxIterations == 0) { + // 直接使用 lastThought 中的完整消息(用于摘要情况) + incompleteResponse = lastThought.get(); + } else { + // 传统格式化(用于普通情况) + incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + } sendFinalAnswer( sessionId, listener, @@ -1087,7 +1095,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); - summaryParams.put("prompt", summaryPrompt); + summaryParams.put(PROMPT, summaryPrompt); summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( From ea708b26df08a8a2fa8889bf1ed87b8f60382604 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 23 Sep 2025 16:22:39 +0800 Subject: [PATCH 41/58] fix:remove useless comment Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index cf9018365e..6371746f8a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1012,7 +1012,7 @@ private void handleMaxIterationsReached( traceNumber, additionalInfo, summaryThought, - 0, // 不使用 maxIterations 格式化,直接使用 summaryResponse + 0, tools ); }, e -> { log.warn("Failed to generate LLM summary", e); })); @@ -1052,10 +1052,8 @@ private void sendTraditionalMaxIterationsResponse( ) { String incompleteResponse; if (maxIterations == 0) { - // 直接使用 lastThought 中的完整消息(用于摘要情况) incompleteResponse = lastThought.get(); } else { - // 传统格式化(用于普通情况) incompleteResponse = (lastThought.get() != null 
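// note: an empty thought or the literal string "null" is treated as missing, in which case the plain max-iterations message is used without the last-thought suffix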
&& !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); From 74349e39f98b8df12df2b0045bf610eb6931b4a9 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 13:39:03 +0800 Subject: [PATCH 42/58] fix: delete stop Signed-off-by: Jiaru Jiang --- .../opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java | 1 - 1 file changed, 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 6371746f8a..56b1388012 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1094,7 +1094,6 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); - summaryParams.putIfAbsent("stop", gson.toJson(new String[] { "\n\n", "```" })); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From 4afcd85ee3586e320f315a42766ef7b125410e59 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:06:51 +0800 Subject: [PATCH 43/58] fix: refactor Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 56b1388012..d8aecd18e1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -999,7 +999,6 @@ private void handleMaxIterationsReached( if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - AtomicReference summaryThought = new AtomicReference<>(summaryResponse); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -1011,12 +1010,18 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - summaryThought, - 0, + summaryResponse, tools ); - }, e -> { log.warn("Failed to generate LLM summary", e); })); + }, e -> { + log.error("Failed to generate LLM summary", e); + listener.onFailure(e); + cleanUpResource(tools); + })); } else { + String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) + ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); sendTraditionalMaxIterationsResponse( sessionId, listener, @@ -1028,8 +1033,7 @@ private void handleMaxIterationsReached( conversationIndexMemory, traceNumber, additionalInfo, - lastThought, - maxIterations, + response, tools ); } @@ -1046,18 +1050,9 @@ private void sendTraditionalMaxIterationsResponse( Memory memory, AtomicInteger traceNumber, Map additionalInfo, - AtomicReference lastThought, - int maxIterations, + String response, Map tools ) { - String incompleteResponse; - if (maxIterations == 0) { - incompleteResponse = lastThought.get(); - } else { - incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - } sendFinalAnswer( sessionId, listener, @@ -1069,7 +1064,7 @@ private void sendTraditionalMaxIterationsResponse( memory, traceNumber, additionalInfo, - incompleteResponse + response ); cleanUpResource(tools); } @@ -1132,8 +1127,8 @@ public String extractSummaryFromResponse(MLTaskResponse response) { } return null; } catch (Exception e) { - log.warn("Failed to extract summary from response", e); - return null; + log.error("Failed to extract summary from response", e); + throw new RuntimeException("Failed to extract summary from response", e); } } From 60613526c1e0814b1632de34042e26dfd32867fb Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 10 Oct 2025 15:55:20 +0800 Subject: [PATCH 44/58] fix: json serialization Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 24 ++++++++++++++----- .../agent/MLChatAgentRunnerTest.java | 17 +++++++++++-- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index d8aecd18e1..56cc910bdf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1084,7 +1084,15 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String // Convert ModelTensors to strings before joining List stepStrings = new ArrayList<>(); for (ModelTensors tensor : stepsSummary) { - stepStrings.add(outputToOutputString(tensor)); + if (tensor != null && tensor.getMlModelTensors() != null) { + for (ModelTensor modelTensor : tensor.getMlModelTensors()) { + if (modelTensor.getResult() != null) { + stepStrings.add(modelTensor.getResult()); + } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { + stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); + } + } + } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put("inputs", summaryPrompt); @@ -1117,12 +1125,16 @@ public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { - Map dataMap = gson.fromJson(outputString, Map.class); - if (dataMap.containsKey("response")) { - String summary = 
String.valueOf(dataMap.get("response")); - if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { - return summary.trim(); + try { + Map dataMap = gson.fromJson(outputString, Map.class); + if (dataMap.containsKey("response")) { + String summary = String.valueOf(dataMap.get("response")); + if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { + return summary.trim(); + } } + } catch (Exception jsonException) { + return outputString.trim(); } } return null; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index e51893b2ee..2c77cb236f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1188,11 +1188,24 @@ public void testMaxIterationsWithSummaryEnabled() { // Reset and setup fresh mocks Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + // First call: LLM response without final_answer to trigger max iterations - // Second call: Summary LLM response + // Second call: Summary LLM response with result field instead of dataAsMap Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL))) - .doAnswer(getLLMAnswer(ImmutableMap.of("response", "Summary: Analysis step was attempted"))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); From 52c949efab16c15325db1b8f0b63d85b2d129d58 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Mon, 13 Oct 2025 14:00:22 +0800 Subject: [PATCH 45/58] fix: parameter Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 56cc910bdf..cd81f6e426 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1095,8 +1095,8 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); - summaryParams.put("inputs", summaryPrompt); summaryParams.put(PROMPT, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); ActionRequest request = new MLPredictionTaskRequest( 
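// note: the summary request appears to reuse the agent's own LLM (model id and parameters from its LLMSpec), with only the prompt and system prompt swapped in for this call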
llmSpec.getModelId(), From becdad79b19a022c4d5e4bbec767382e61fa9652 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 11:45:17 +0800 Subject: [PATCH 46/58] delete:summarize_when_max_iteration Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index cd81f6e426..3502a9999d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -355,7 +355,6 @@ private void runReAct( StringBuilder scratchpadBuilder = new StringBuilder(); final List interactions = new CopyOnWriteArrayList<>(); - List interactions = new CopyOnWriteArrayList<>(); StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); AtomicReference newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); tmpParameters.put(PROMPT, newPrompt.get()); @@ -994,9 +993,6 @@ private void handleMaxIterationsReached( LLMSpec llmSpec, String tenantId ) { - boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false")); - - if (shouldSummarize && !traceTensors.isEmpty()) { generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); sendTraditionalMaxIterationsResponse( @@ -1018,25 +1014,6 @@ private void handleMaxIterationsReached( listener.onFailure(e); cleanUpResource(tools); })); - } else { - String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) - ? String.format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) - : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); - sendTraditionalMaxIterationsResponse( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - conversationIndexMemory, - traceNumber, - additionalInfo, - response, - tools - ); - } } private void sendTraditionalMaxIterationsResponse( From 209565122990b1548aaf0d192fcfb9803f29b3ae Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 14:08:42 +0800 Subject: [PATCH 47/58] fix: unit test Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 58 ++++++++++++------- .../agent/MLChatAgentRunnerTest.java | 19 +++--- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 3502a9999d..92babb5d1c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -993,27 +993,43 @@ private void handleMaxIterationsReached( LLMSpec llmSpec, String tenantId ) { - generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> { - String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary); - sendTraditionalMaxIterationsResponse( - sessionId, - listener, - question, - parentInteractionId, - verbose, - traceDisabled, - traceTensors, - conversationIndexMemory, - traceNumber, - additionalInfo, - summaryResponse, - tools - ); - }, e -> { - log.error("Failed to generate LLM summary", e); - listener.onFailure(e); - cleanUpResource(tools); - })); + ActionListener responseListener = ActionListener.wrap(response -> { + sendTraditionalMaxIterationsResponse( + sessionId, + listener, + question, + parentInteractionId, + verbose, + traceDisabled, + traceTensors, + conversationIndexMemory, + traceNumber, + additionalInfo, + response, + tools + ); + }, listener::onFailure); + + generateLLMSummary( + traceTensors, + llmSpec, + tenantId, + ActionListener + .wrap( + summary -> responseListener + .onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)), + e -> { + log.error("Failed to generate LLM summary, using fallback strategy", e); + String fallbackResponse = (lastThought.get() != null + && !lastThought.get().isEmpty() + && !"null".equals(lastThought.get())) + ? String + .format("%s. 
Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) + : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); + responseListener.onResponse(fallbackResponse); + } + ) + ); } private void sendTraditionalMaxIterationsResponse( diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 2c77cb236f..0fc1fa32aa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1211,8 +1211,6 @@ public void testMaxIterationsWithSummaryEnabled() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); // Verify response is captured @@ -1234,7 +1232,7 @@ public void testMaxIterationsWithSummaryEnabled() { @Test public void testMaxIterationsWithSummaryDisabled() { - // Create LLM spec with max_iteration = 1 and summary disabled + // Create LLM spec with max_iteration = 1 LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); final MLAgent mlAgent = MLAgent @@ -1248,15 +1246,16 @@ public void testMaxIterationsWithSummaryDisabled() { // Reset client mock for this test Mockito.reset(client); - // Mock LLM response that doesn't contain final_answer - Mockito - .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + // First call: LLM response without final_answer to trigger max iterations + // Second call: Summary LLM fails + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false"); mlChatAgentRunner.run(mlAgent, params, agentActionListener); @@ -1269,7 +1268,7 @@ public void testMaxIterationsWithSummaryDisabled() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains traditional max iterations message + // Verify the response uses fallback strategy with last thought String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals("Agent reached maximum iterations (1) without completing the task. 
Last thought: I need to use the tool", response); } From 46a332e0a19d790efb262130dbe6b36a5ba39a34 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 17:37:35 +0800 Subject: [PATCH 48/58] fix: summary message Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 9 +++++---- .../engine/algorithms/agent/MLChatAgentRunnerTest.java | 5 ++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 92babb5d1c..06e7853d58 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -130,7 +130,8 @@ public class MLChatAgentRunner implements MLAgentRunner { private static final String DEFAULT_MAX_ITERATIONS = "10"; private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; - private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s"; + private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + + ". Here's a summary of the steps completed so far:\n\n%s"; private Client client; private Settings settings; @@ -1103,11 +1104,11 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String ); client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { String summary = extractSummaryFromResponse(response); - if (summary != null) { - listener.onResponse(summary); - } else { + if (summary == null) { listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; } + listener.onResponse(summary); }, listener::onFailure)); } catch (Exception e) { listener.onFailure(e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 0fc1fa32aa..47625974f4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1225,7 +1225,10 @@ public void testMaxIterationsWithSummaryEnabled() { // Verify the response contains summary message String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertTrue( - response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:") + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps completed so far:" + ) ); assertTrue(response.contains("Summary: Analysis step was attempted")); } From ac7390a58e27db26c5c299635c57dc38125804b7 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 29 Oct 2025 23:27:55 +0800 Subject: [PATCH 49/58] fix: summary prompt Signed-off-by: Jiaru Jiang --- build.gradle | 2 +- .../engine/algorithms/agent/MLChatAgentRunner.java | 12 +----------- .../ml/engine/algorithms/agent/PromptTemplate.java | 2 +- .../algorithms/agent/MLChatAgentRunnerTest.java | 7 +------ .../ml/helper/ModelAccessControlHelper.java | 10 ++++++++++ 5 files changed, 14 insertions(+), 19 deletions(-) diff --git a/build.gradle b/build.gradle index cfc4419471..87d59b0299 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") asm_version = "9.7" diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 06e7853d58..2b5f3e742b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1119,17 +1119,7 @@ public String extractSummaryFromResponse(MLTaskResponse response) { try { String outputString = outputToOutputString(response.getOutput()); if (outputString != null && !outputString.trim().isEmpty()) { - try { - Map dataMap = gson.fromJson(outputString, Map.class); - if (dataMap.containsKey("response")) { - String summary = String.valueOf(dataMap.get("response")); - if (summary != null && !summary.trim().isEmpty() && !"null".equals(summary)) { - return summary.trim(); - } - } - } catch (Exception jsonException) { - return outputString.trim(); - } + return outputString.trim(); } return null; } catch (Exception e) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java index f1b1147a7d..cbeb3fa54d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/PromptTemplate.java @@ -142,5 +142,5 @@ public class PromptTemplate { - Break complex searches into simpler queries when appropriate."""; public static final String SUMMARY_PROMPT_TEMPLATE = - "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}"; + "Please provide a concise summary of the following agent execution steps. 
Focus on what the agent was trying to accomplish and what progress was made:\n\n%s"; } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 47625974f4..e76acd6cc3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1420,12 +1420,7 @@ public void testExtractSummaryFromResponse() { .asList( ModelTensors .builder() - .mlModelTensors( - Arrays - .asList( - ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build() - ) - ) + .mlModelTensors(Arrays.asList(ModelTensor.builder().result("Valid summary text").build())) .build() ) ) diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index adcc5d196e..43788fb4d9 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -288,6 +288,16 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti } } + /** + * Checks whether to utilize new ResourceAuthz + * @param resourceType for which to decide whether to use resource authz + * @return true if the resource-sharing feature is enabled, false otherwise. + */ + public static boolean shouldUseResourceAuthz(String resourceType) { + var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); + return client != null; + } + public boolean skipModelAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. 
When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as From fce546f2b40c44c51c7919c2f2510097551615cc Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 30 Oct 2025 10:27:25 +0800 Subject: [PATCH 50/58] fix: configuration file Signed-off-by: Jiaru Jiang --- build.gradle | 2 +- .../java/org/opensearch/ml/helper/ModelAccessControlHelper.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 87d59b0299..cfc4419471 100644 --- a/build.gradle +++ b/build.gradle @@ -11,7 +11,7 @@ buildscript { ext { opensearch_group = "org.opensearch" isSnapshot = "true" == System.getProperty("build.snapshot", "true") - opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT") + opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT") buildVersionQualifier = System.getProperty("build.version_qualifier", "") asm_version = "9.7" diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index 43788fb4d9..77fe7a370b 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -295,7 +295,7 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti */ public static boolean shouldUseResourceAuthz(String resourceType) { var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); - return client != null; + return client != null && client.isFeatureEnabledForType(resourceType); } public boolean skipModelAccessControl(User user) { From 6795f7ac5df388ca9b0dc4665bc146656d0409d3 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 30 Oct 2025 17:20:14 +0800 Subject: [PATCH 51/58] add: planner max step summary Signed-off-by: Jiaru Jiang --- .../MLPlanExecuteAndReflectAgentRunner.java | 145 ++++++++++++++--- .../agent/MLChatAgentRunnerTest.java | 40 +++-- ...LPlanExecuteAndReflectAgentRunnerTest.java | 148 +++++++++++++++++- 3 files changed, 303 insertions(+), 30 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index f70ddc9f50..a2fe844c5e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -34,6 +34,7 @@ import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.FINAL_RESULT_RESPONSE_INSTRUCTIONS; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLANNER_RESPONSIBILITY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.PLAN_EXECUTE_REFLECT_RESPONSE_FORMAT; +import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; import java.util.ArrayList; import java.util.HashMap; @@ -388,27 +389,8 @@ private void executePlanningLoop( int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); String parentInteractionId = allParams.get(MLAgentExecutor.PARENT_INTERACTION_ID); - // completedSteps stores the step and its result, hence divide by 2 to find total steps completed - // on reaching max iteration, 
update parent interaction question with last executed step rather than task to allow continue using - // memory_id if (stepsExecuted >= maxSteps) { - String finalResult = String - .format( - "Max Steps Limit Reached. Use memory_id with same task to restart. \n " - + "Last executed step: %s, \n " - + "Last executed step result: %s", - completedSteps.get(completedSteps.size() - 2), - completedSteps.getLast() - ); - saveAndReturnFinalResult( - memory, - parentInteractionId, - allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), - allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), - finalResult, - null, - finalListener - ); + handleMaxStepsReached(llm, allParams, completedSteps, memory, parentInteractionId, finalListener); return; } MLPredictionTaskRequest request; @@ -762,4 +744,127 @@ static List createModelTensors( Map getTaskUpdates() { return taskUpdates; } + + private void handleMaxStepsReached( + LLMSpec llm, + Map allParams, + List completedSteps, + Memory memory, + String parentInteractionId, + ActionListener finalListener + ) { + int maxSteps = Integer.parseInt(allParams.getOrDefault(MAX_STEPS_EXECUTED_FIELD, DEFAULT_MAX_STEPS_EXECUTED)); + log.info("[SUMMARY] Max steps reached. Completed steps: {}", completedSteps.size()); + + ActionListener responseListener = ActionListener.wrap(response -> { + saveAndReturnFinalResult( + (ConversationIndexMemory) memory, + parentInteractionId, + allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), + allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), + response, + null, + finalListener + ); + }, finalListener::onFailure); + + generateSummary(llm, completedSteps, allParams.get(TENANT_ID_FIELD), ActionListener.wrap(summary -> { + log.info("Summary generated successfully"); + responseListener + .onResponse( + String.format("Max Steps Limit (%d) Reached. Here's a summary of the steps completed so far:\n\n%s", maxSteps, summary) + ); + }, e -> { + log.error("Summary generation failed, using fallback", e); + String fallbackResult = completedSteps.isEmpty() || completedSteps.size() < 2 + ? String.format("Max Steps Limit (%d) Reached. Use memory_id with same task to restart.", maxSteps) + : String + .format( + "Max Steps Limit (%d) Reached. Use memory_id with same task to restart. 
\n " + + "Last executed step: %s, \n " + + "Last executed step result: %s", + maxSteps, + completedSteps.get(completedSteps.size() - 2), + completedSteps.getLast() + ); + responseListener.onResponse(fallbackResult); + })); + } + + private void generateSummary(LLMSpec llmSpec, List completedSteps, String tenantId, ActionListener listener) { + if (completedSteps == null || completedSteps.isEmpty()) { + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return; + } + + try { + Map summaryParams = new HashMap<>(); + if (llmSpec.getParameters() != null) { + summaryParams.putAll(llmSpec.getParameters()); + } + + String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", completedSteps)); + summaryParams.put(PROMPT_FIELD, summaryPrompt); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); + + MLPredictionTaskRequest request = new MLPredictionTaskRequest( + llmSpec.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) + .build(), + null, + tenantId + ); + + client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { + String summary = extractSummaryFromResponse(response); + if (summary == null || summary.trim().isEmpty()) { + log.error("Extracted summary is empty"); + listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); + return; + } + listener.onResponse(summary); + }, listener::onFailure)); + } catch (Exception e) { + listener.onFailure(e); + } + } + + private String extractSummaryFromResponse(MLTaskResponse response) { + try { + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output != null && output.getMlModelOutputs() != null && !output.getMlModelOutputs().isEmpty()) { + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors != null && tensors.getMlModelTensors() != null && !tensors.getMlModelTensors().isEmpty()) { + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + if (tensor.getDataAsMap() != null) { + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey(RESPONSE_FIELD)) { + return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); + } + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, "$.output.message.content[0].text"); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + } + log + .error( + "Summary generate error. No result/response field. Available: {}", + tensor.getDataAsMap() != null ? 
tensor.getDataAsMap().keySet() : "null" + ); + } + } + return null; + } catch (Exception e) { + log.error("Summary extraction failed", e); + throw new RuntimeException("Failed to extract summary from response", e); + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index e76acd6cc3..e8da40b428 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1031,11 +1031,19 @@ public void testMaxIterationsReached() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response that doesn't contain final_answer to force max iterations - Mockito - .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))) - .when(client) - .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response without final_answer to force max iterations + // Second call: Summary LLM response + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("The agent attempted to use the first tool").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); @@ -1051,9 +1059,15 @@ public void testMaxIterationsReached() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains max iterations message + // Verify the response contains max iterations message with summary String response = (String) agentOutput.get(0).getDataAsMap().get("response"); - assertEquals("Agent reached maximum iterations (1) without completing the task", response); + assertTrue( + response + .startsWith( + "Agent reached maximum iterations (1) without completing the task. 
Here's a summary of the steps completed so far:" + ) + ); + assertTrue(response.contains("The agent attempted to use the first tool")); } @Test @@ -1070,9 +1084,17 @@ public void testMaxIterationsReachedWithValidThought() { .tools(Arrays.asList(firstToolSpec)) .build(); - // Mock LLM response with valid thought + // Reset client mock for this test + Mockito.reset(client); + // First call: LLM response with valid thought to trigger max iterations + // Second call: Summary LLM fails to trigger fallback Mockito .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the first tool", "action", FIRST_TOOL))) + .doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("LLM summary generation failed")); + return null; + }) .when(client) .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); @@ -1090,7 +1112,7 @@ public void testMaxIterationsReachedWithValidThought() { List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); assertEquals(1, agentOutput.size()); - // Verify the response contains the last valid thought instead of max iterations message + // Verify the response contains the last valid thought (fallback when summary fails) String response = (String) agentOutput.get(0).getDataAsMap().get("response"); assertEquals( "Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the first tool", diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index 985661a9c9..b730c139d6 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -383,7 +383,153 @@ public void testMessageHistoryLimits() { assertEquals("3", executorParams.get("message_history_limit")); } - // ToDo: add test case for when max steps is reached + @Test + public void testMaxStepsReachedWithSummary() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Summary of work done").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof 
ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Summary of work done")); + } + + @Test + public void testMaxStepsReachedWithSummaryGeneration() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor.builder().result("Generated summary of completed steps").build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached. 
Here's a summary of the steps completed so far:")); + assertTrue(finalResponse.contains("Generated summary of completed steps")); + } + + @Test + public void testMaxStepsReachedWithSummaryFailure() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Summary generation failed")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testMaxStepsReachedWithEmptyCompletedSteps() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Collections.emptyList()); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("Completed steps cannot be null or empty")); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object response = objectCaptor.getValue(); + assertTrue(response instanceof ModelTensorOutput); + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) response; + + List mlModelOutputs = modelTensorOutput.getMlModelOutputs(); + ModelTensor responseTensor = mlModelOutputs.get(1).getMlModelTensors().get(0); + String finalResponse = (String) responseTensor.getDataAsMap().get("response"); + assertTrue(finalResponse.contains("Max Steps Limit (0) Reached")); + } private MLAgent createMLAgentWithTools() { LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build(); From 4815ab4ae3107fa34fcbc0ebf3b59a378b19de37 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Fri, 31 Oct 2025 10:29:44 +0800 Subject: [PATCH 52/58] fix: system prompt Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 2b5f3e742b..9ee06d4fd4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -1090,7 +1090,7 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String } String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); summaryParams.put(PROMPT, summaryPrompt); - summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, ""); + summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, SUMMARY_PROMPT_TEMPLATE); ActionRequest request = new MLPredictionTaskRequest( llmSpec.getModelId(), From ab38b1f188e8b53d12ba2867349ebd1e170e2aa5 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 4 Nov 2025 14:17:15 +0800 Subject: [PATCH 53/58] fix: parseLLMOutput Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 36 ++++++++- .../MLPlanExecuteAndReflectAgentRunner.java | 53 +++++++------ .../TransportUpdateModelGroupActionTests.java | 74 +++++++++++++++++++ 3 files changed, 136 insertions(+), 27 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 9ee06d4fd4..191b9204ad 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -14,6 +14,7 @@ import static org.opensearch.ml.common.utils.ToolUtils.parseResponse; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.DISABLE_TRACE; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.INTERACTIONS_PREFIX; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_CHAT_HISTORY_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_PREFIX; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.PROMPT_SUFFIX; @@ -86,6 +87,7 @@ import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; +import com.jayway.jsonpath.JsonPath; import lombok.Data; import lombok.NoArgsConstructor; @@ -1117,10 +1119,38 @@ void generateLLMSummary(List stepsSummary, LLMSpec llmSpec, String public String extractSummaryFromResponse(MLTaskResponse response) { try { - String outputString = outputToOutputString(response.getOutput()); - if (outputString != null && !outputString.trim().isEmpty()) { - return outputString.trim(); + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = 
tensor.getDataAsMap(); + if (dataMap.containsKey("response")) { + return String.valueOf(dataMap.get("response")).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); + } + } + + log.error("Summary generate error. No result/response field found. Available fields: {}", dataMap.keySet()); return null; } catch (Exception e) { log.error("Failed to extract summary from response", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index a2fe844c5e..ea354b5c71 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -835,32 +835,37 @@ private void generateSummary(LLMSpec llmSpec, List completedSteps, Strin private String extractSummaryFromResponse(MLTaskResponse response) { try { ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); - if (output != null && output.getMlModelOutputs() != null && !output.getMlModelOutputs().isEmpty()) { - ModelTensors tensors = output.getMlModelOutputs().getFirst(); - if (tensors != null && tensors.getMlModelTensors() != null && !tensors.getMlModelTensors().isEmpty()) { - ModelTensor tensor = tensors.getMlModelTensors().getFirst(); - if (tensor.getResult() != null) { - return tensor.getResult().trim(); - } - if (tensor.getDataAsMap() != null) { - Map dataMap = tensor.getDataAsMap(); - if (dataMap.containsKey(RESPONSE_FIELD)) { - return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); - } - if (dataMap.containsKey("output")) { - Object outputObj = JsonPath.read(dataMap, "$.output.message.content[0].text"); - if (outputObj != null) { - return String.valueOf(outputObj).trim(); - } - } - } - log - .error( - "Summary generate error. No result/response field. Available: {}", - tensor.getDataAsMap() != null ? tensor.getDataAsMap().keySet() : "null" - ); + if (output == null || output.getMlModelOutputs() == null || output.getMlModelOutputs().isEmpty()) { + return null; + } + + ModelTensors tensors = output.getMlModelOutputs().getFirst(); + if (tensors == null || tensors.getMlModelTensors() == null || tensors.getMlModelTensors().isEmpty()) { + return null; + } + + ModelTensor tensor = tensors.getMlModelTensors().getFirst(); + if (tensor.getResult() != null) { + return tensor.getResult().trim(); + } + + if (tensor.getDataAsMap() == null) { + return null; + } + + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey(RESPONSE_FIELD)) { + return String.valueOf(dataMap.get(RESPONSE_FIELD)).trim(); + } + + if (dataMap.containsKey("output")) { + Object outputObj = JsonPath.read(dataMap, LLM_RESPONSE_FILTER); + if (outputObj != null) { + return String.valueOf(outputObj).trim(); } } + + log.error("Summary generate error. No result/response field found. 
Available fields: {}", dataMap.keySet()); return null; } catch (Exception e) { log.error("Summary extraction failed", e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index c62716d793..cf1fc05dbc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -454,6 +454,80 @@ public void test_ExceptionSecurityDisabledCluster() { ); } + public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws Exception { + // Enable RSC fast-path. + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + // when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + + // No ACL changes in request (so even legacy would pass, but we won't go there). + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy validation was skipped. + verify(modelAccessControlHelper, times(0)).isSecurityEnabledAndModelAccessControlEnabled(any()); + verify(modelAccessControlHelper, times(0)).isOwner(any(), any()); + verify(modelAccessControlHelper, times(0)).isAdmin(any()); + + // Update succeeded. + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + + public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws Exception { + // RSC feature on, but type disabled → legacy path. + ResourceSharingClient rsc = mock(ResourceSharingClient.class); + ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); + // when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + + // Allow legacy validation to pass: + // security/model-access-control enabled: + when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); + // user is allowed to update: + when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); + when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); + + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy path consulted helper + verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); + + // Update succeeded + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + + public void test_Update_RSC_FeatureDisabled_UsesLegacyValidation() throws Exception { + // Entire feature disabled → legacy path. 
+ ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); + + // Allow legacy validation to pass: + when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); + when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); + when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); + when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); + + MLUpdateModelGroupRequest req = prepareRequest(null, null, null); + + transportUpdateModelGroupAction.doExecute(null, req, actionListener); + + // Legacy path consulted helper + verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); + + // Update succeeded + ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); + verify(actionListener).onResponse(captor.capture()); + assertEquals("Updated", captor.getValue().getStatus()); + } + private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { MLUpdateModelGroupInput UpdateModelGroupInput = MLUpdateModelGroupInput .builder() From 68ba5d4cdd309641ec2acbfc308c7ab918618831 Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Tue, 4 Nov 2025 15:16:30 +0800 Subject: [PATCH 54/58] add: test cases Signed-off-by: Jiaru Jiang --- .../agent/MLChatAgentRunnerTest.java | 37 ++++ ...LPlanExecuteAndReflectAgentRunnerTest.java | 194 ++++++++++++++++++ .../TransportUpdateModelGroupActionTests.java | 4 +- 3 files changed, 233 insertions(+), 2 deletions(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index e8da40b428..9d3e2bcb43 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1463,4 +1463,41 @@ public void testGenerateLLMSummaryWithNullSteps() { verify(listener).onFailure(any(IllegalArgumentException.class)); } + + @Test + public void testExtractSummaryFromResponse_WithResponseField() { + Map dataMap = new HashMap<>(); + dataMap.put("response", "Summary from response field"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals("Summary from response field", result); + } + + @Test + public void testExtractSummaryFromResponse_WithNullDataMap() { + ModelTensor tensor = ModelTensor.builder().build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } + + @Test + public void testExtractSummaryFromResponse_WithEmptyDataMap() { + Map dataMap = new HashMap<>(); + 
dataMap.put("other_field", "some value"); + ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + + String result = mlChatAgentRunner.extractSummaryFromResponse(response); + assertEquals(null, result); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java index b730c139d6..d07951cdef 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java @@ -913,4 +913,198 @@ public void testUpdateTaskWithExecutorAgentInfo() { mlTaskUtilsMockedStatic.verify(() -> MLTaskUtils.updateMLTaskDirectly(eq(taskId), eq(taskUpdates), eq(client), any())); } } + + @Test + public void testExecutionWithNullStepResult() { + MLAgent mlAgent = createMLAgentWithTools(); + + // Setup LLM response for planning phase - returns steps to execute + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", "{\"steps\":[\"step1\"], \"result\":\"\"}")) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + // Setup executor response with tensor that has null dataMap - this will hit line 465 + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor memoryIdTensor = ModelTensor.builder().name(MLAgentExecutor.MEMORY_ID).result("test_memory_id").build(); + ModelTensor parentIdTensor = ModelTensor.builder().name(MLAgentExecutor.PARENT_INTERACTION_ID).result("test_parent_id").build(); + // This tensor will return null from parseTensorDataMap, hitting the stepResult != null check + ModelTensor nullDataTensor = ModelTensor.builder().name("other").build(); + ModelTensors modelTensors = ModelTensors + .builder() + .mlModelTensors(Arrays.asList(memoryIdTensor, parentIdTensor, nullDataTensor)) + .build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + when(mlExecuteTaskResponse.getOutput()).thenReturn(mlModelTensorOutput); + listener.onResponse(mlExecuteTaskResponse); + return null; + }).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(MLExecuteTaskRequest.class), any()); + + Map params = new HashMap<>(); + params.put("question", "test question"); + params.put("parent_interaction_id", "test_parent_interaction_id"); + + // Capture the exception in the listener + doAnswer(invocation -> { + Exception e = invocation.getArgument(0); + assertTrue(e instanceof IllegalStateException); + assertEquals("No valid response found in ReAct 
agent output", e.getMessage()); + return null; + }).when(agentActionListener).onFailure(any()); + + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + // Verify that onFailure was called with the expected exception + verify(agentActionListener).onFailure(any(IllegalStateException.class)); + } + + @Test + public void testMaxStepsWithSingleCompletedStep() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(0); + listener.onResponse(Arrays.asList(Interaction.builder().id("i1").input("step1").response("").build())); + return null; + }).when(conversationIndexMemory).getMessages(any(), anyInt()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit (0) Reached")); + } + + @Test + public void testSummaryExtractionWithResultField() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().result("Summary from result").build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Summary from result")); + } + + @Test + public void testSummaryExtractionWithEmptyResponse() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + ModelTensor tensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", " ")).build(); + when(mlTaskResponse.getOutput()) + .thenReturn( + ModelTensorOutput + .builder() + .mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build())) + .build() + ); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = 
invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + String response = (String) ((ModelTensorOutput) objectCaptor.getValue()) + .getMlModelOutputs() + .get(1) + .getMlModelTensors() + .get(0) + .getDataAsMap() + .get("response"); + assertTrue(response.contains("Max Steps Limit")); + } + + @Test + public void testSummaryExtractionWithNullOutput() { + MLAgent mlAgent = createMLAgentWithTools(); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + when(mlTaskResponse.getOutput()).thenReturn(null); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateResponse); + return null; + }).when(mlMemoryManager).updateInteraction(any(), any(), any()); + + Map params = new HashMap<>(); + params.put("question", "test"); + params.put("parent_interaction_id", "pid"); + params.put("max_steps", "0"); + mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener); + + verify(agentActionListener).onResponse(any()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index cf1fc05dbc..ed4766799a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -458,7 +458,7 @@ public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() t // Enable RSC fast-path. ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - // when(rsc.isFeatureEnabledForType(any())).thenReturn(true); + when(rsc.isFeatureEnabledForType(any())).thenReturn(true); // No ACL changes in request (so even legacy would pass, but we won't go there). MLUpdateModelGroupRequest req = prepareRequest(null, null, null); @@ -480,7 +480,7 @@ public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() t // RSC feature on, but type disabled → legacy path. 
ResourceSharingClient rsc = mock(ResourceSharingClient.class); ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - // when(rsc.isFeatureEnabledForType(any())).thenReturn(false); + when(rsc.isFeatureEnabledForType(any())).thenReturn(false); // Allow legacy validation to pass: // security/model-access-control enabled: From 691efee5d349ede2a7a374f5fc4de4f83b2f49bb Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 6 Nov 2025 16:01:46 +0800 Subject: [PATCH 55/58] add: test cases for fallback Signed-off-by: Jiaru Jiang --- .../algorithms/agent/MLChatAgentRunner.java | 3 -- .../agent/MLChatAgentRunnerTest.java | 47 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 191b9204ad..0b4f9aacc0 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -451,7 +451,6 @@ private void runReAct( lastThought, maxIterations, tools, - tmpParameters, llm, tenantId ); @@ -558,7 +557,6 @@ private void runReAct( lastThought, maxIterations, tools, - tmpParameters, llm, tenantId ); @@ -992,7 +990,6 @@ private void handleMaxIterationsReached( AtomicReference lastThought, int maxIterations, Map tools, - Map parameters, LLMSpec llmSpec, String tenantId ) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 9d3e2bcb43..c22af74f1c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -1500,4 +1500,51 @@ public void testExtractSummaryFromResponse_WithEmptyDataMap() { String result = mlChatAgentRunner.extractSummaryFromResponse(response); assertEquals(null, result); } + + @Test + public void testExtractSummaryFromResponse_ThrowsException_FallbackStrategyUsed() { + LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build(); + MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build(); + final MLAgent mlAgent = MLAgent + .builder() + .name("TestAgent") + .type(MLAgentType.CONVERSATIONAL.name()) + .llm(llmSpec) + .memory(mlMemorySpec) + .tools(Arrays.asList(firstToolSpec)) + .build(); + + Mockito.reset(client); + Mockito.reset(firstTool); + when(firstTool.getName()).thenReturn(FIRST_TOOL); + when(firstTool.validate(Mockito.anyMap())).thenReturn(true); + Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any()); + + Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "Analyzing the problem", "action", FIRST_TOOL))).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + Map invalidDataMap = new HashMap<>(); + invalidDataMap.put("output", new HashMap<>()); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(invalidDataMap).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build(); + listener.onResponse(mlTaskResponse); + return null; + }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class)); + + Map params = new HashMap<>(); + params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); + + verify(agentActionListener).onResponse(objectCaptor.capture()); + Object capturedResponse = objectCaptor.getValue(); + assertTrue(capturedResponse instanceof ModelTensorOutput); + + ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse; + List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); + assertEquals(1, agentOutput.size()); + + String response = (String) agentOutput.get(0).getDataAsMap().get("response"); + assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: Analyzing the problem", response); + } } From 123420992718cd4a46bbad8083f69679fd03eefd Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Wed, 12 Nov 2025 18:24:36 +0800 Subject: [PATCH 56/58] fix:spotlessApply Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunner.java | 11 ++++++++--- .../agent/MLPlanExecuteAndReflectAgentRunner.java | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 0b4f9aacc0..884776cf8f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -36,8 +36,8 @@ import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; -import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; +import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; import java.security.PrivilegedActionException; import java.util.ArrayList; @@ -60,13 +60,16 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLMemoryType; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.contextmanager.ContextManagerContext; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.hooks.HookRegistry; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.memory.Memory; import org.opensearch.ml.common.memory.Message; import org.opensearch.ml.common.output.model.ModelTensor; @@ -74,6 +77,8 @@ import org.opensearch.ml.common.output.model.ModelTensors; import 
org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.agents.AgentContextUtil; import org.opensearch.ml.engine.encryptor.Encryptor; @@ -984,7 +989,7 @@ private void handleMaxIterationsReached( boolean verbose, boolean traceDisabled, List traceTensors, - ConversationIndexMemory conversationIndexMemory, + Memory memory, AtomicInteger traceNumber, Map additionalInfo, AtomicReference lastThought, @@ -1002,7 +1007,7 @@ private void handleMaxIterationsReached( verbose, traceDisabled, traceTensors, - conversationIndexMemory, + memory, traceNumber, additionalInfo, response, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index ea354b5c71..360f9a50d3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -758,7 +758,7 @@ private void handleMaxStepsReached( ActionListener responseListener = ActionListener.wrap(response -> { saveAndReturnFinalResult( - (ConversationIndexMemory) memory, + memory, parentInteractionId, allParams.get(EXECUTOR_AGENT_MEMORY_ID_FIELD), allParams.get(EXECUTOR_AGENT_PARENT_INTERACTION_ID_FIELD), From 62086c0b58b0ea331c69127ebd59e740e5ded65c Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 13 Nov 2025 10:31:13 +0800 Subject: [PATCH 57/58] fix:recover file Signed-off-by: Jiaru Jiang --- .../ml/helper/ModelAccessControlHelper.java | 10 --- .../TransportUpdateModelGroupActionTests.java | 74 ------------------- 2 files changed, 84 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java index 77fe7a370b..adcc5d196e 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ModelAccessControlHelper.java @@ -288,16 +288,6 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti } } - /** - * Checks whether to utilize new ResourceAuthz - * @param resourceType for which to decide whether to use resource authz - * @return true if the resource-sharing feature is enabled, false otherwise. - */ - public static boolean shouldUseResourceAuthz(String resourceType) { - var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient(); - return client != null && client.isFeatureEnabledForType(resourceType); - } - public boolean skipModelAccessControl(User user) { // Case 1: user == null when 1. Security is disabled. 2. 
When user is super-admin // Case 2: If Security is enabled and filter is disabled, proceed with search as diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index ed4766799a..c62716d793 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -454,80 +454,6 @@ public void test_ExceptionSecurityDisabledCluster() { ); } - public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() throws Exception { - // Enable RSC fast-path. - ResourceSharingClient rsc = mock(ResourceSharingClient.class); - ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - when(rsc.isFeatureEnabledForType(any())).thenReturn(true); - - // No ACL changes in request (so even legacy would pass, but we won't go there). - MLUpdateModelGroupRequest req = prepareRequest(null, null, null); - - transportUpdateModelGroupAction.doExecute(null, req, actionListener); - - // Legacy validation was skipped. - verify(modelAccessControlHelper, times(0)).isSecurityEnabledAndModelAccessControlEnabled(any()); - verify(modelAccessControlHelper, times(0)).isOwner(any(), any()); - verify(modelAccessControlHelper, times(0)).isAdmin(any()); - - // Update succeeded. - ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); - verify(actionListener).onResponse(captor.capture()); - assertEquals("Updated", captor.getValue().getStatus()); - } - - public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() throws Exception { - // RSC feature on, but type disabled → legacy path. - ResourceSharingClient rsc = mock(ResourceSharingClient.class); - ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc); - when(rsc.isFeatureEnabledForType(any())).thenReturn(false); - - // Allow legacy validation to pass: - // security/model-access-control enabled: - when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); - // user is allowed to update: - when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); - when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); - when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); - when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); - - MLUpdateModelGroupRequest req = prepareRequest(null, null, null); - - transportUpdateModelGroupAction.doExecute(null, req, actionListener); - - // Legacy path consulted helper - verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); - - // Update succeeded - ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); - verify(actionListener).onResponse(captor.capture()); - assertEquals("Updated", captor.getValue().getStatus()); - } - - public void test_Update_RSC_FeatureDisabled_UsesLegacyValidation() throws Exception { - // Entire feature disabled → legacy path. 
- ResourceSharingClientAccessor.getInstance().setResourceSharingClient(null); - - // Allow legacy validation to pass: - when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); - when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true); - when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true); - when(modelAccessControlHelper.isAdmin(any())).thenReturn(false); - when(modelAccessControlHelper.isUserHasBackendRole(any(), any())).thenReturn(true); - - MLUpdateModelGroupRequest req = prepareRequest(null, null, null); - - transportUpdateModelGroupAction.doExecute(null, req, actionListener); - - // Legacy path consulted helper - verify(modelAccessControlHelper, times(1)).isSecurityEnabledAndModelAccessControlEnabled(any()); - - // Update succeeded - ArgumentCaptor captor = ArgumentCaptor.forClass(MLUpdateModelGroupResponse.class); - verify(actionListener).onResponse(captor.capture()); - assertEquals("Updated", captor.getValue().getStatus()); - } - private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { MLUpdateModelGroupInput UpdateModelGroupInput = MLUpdateModelGroupInput .builder() From e0dc13d08e38578f468335010534d1e2d73de21a Mon Sep 17 00:00:00 2001 From: Jiaru Jiang Date: Thu, 13 Nov 2025 11:15:09 +0800 Subject: [PATCH 58/58] fix:recover file Signed-off-by: Jiaru Jiang --- .../ml/engine/algorithms/agent/MLChatAgentRunnerTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index c22af74f1c..4dcbb88c54 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -681,7 +681,7 @@ public void testToolThrowException() { .when(firstTool) .run(Mockito.anyMap(), toolListenerCaptor.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called verify(firstTool).run(any(), any());