From e8f490c341d636b99b76d7e8dbbf31c195e47e96 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Fri, 5 Sep 2025 19:47:38 -0700 Subject: [PATCH 1/4] [Agentic Memory] Allow other LLMs for Fact Extraction Signed-off-by: rithin-pullela-aws --- .../MemoryContainerConstants.java | 211 ++++++- .../memory/MemoryProcessingService.java | 121 ++-- .../memory/MemoryProcessingServiceTests.java | 98 +--- .../ml/rest/RestMLAgenticMemoryIT.java | 520 ++++++++++++++++++ 4 files changed, 773 insertions(+), 177 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java index 7965f0e70a..ec7c6de438 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java @@ -102,16 +102,217 @@ public class MemoryContainerConstants { public static final String MAX_MESSAGES_EXCEEDED_ERROR = "Cannot process more than 10 messages in a single request"; // Memory decision fields + public static final String FACTS_FIELD = "facts"; public static final String MEMORY_DECISION_FIELD = "memory_decision"; public static final String OLD_MEMORY_FIELD = "old_memory"; public static final String RETRIEVED_FACTS_FIELD = "retrieved_facts"; public static final String EVENT_FIELD = "event"; public static final String SCORE_FIELD = "score"; - // LLM System Prompts - public static final String PERSONAL_INFORMATION_ORGANIZER_PROMPT = - "\nPersonal Information Organizer\nExtract and organize personal information shared within conversations.\n\nCarefully read the conversation.\nIdentify and extract any personal information shared by participants.\nFocus on details that help build a profile of the person, including but not limited to:\n\nNames and relationships\nProfessional information (job, company, role, responsibilities)\nPersonal interests and hobbies\nSkills and expertise\nPreferences and opinions\nGoals and aspirations\nChallenges or pain points\nBackground and experiences\nContact information (if shared)\nAvailability and schedule preferences\n\n\nOrganize each piece of information as a separate fact.\nEnsure facts are specific, clear, and preserve the original context.\nNever answer user's question or fulfill user's requirement. You are a personal information manager, not a helpful assistant.\nInclude the person who shared the information when relevant.\nDo not make assumptions or inferences beyond what is explicitly stated.\nIf no personal information is found, return an empty list.\n\n\nYou should always return and only return the extracted facts as a JSON object with a \"facts\" array.\n\n{\n \"facts\": [\n \"User's name is John Smith\",\n \"John works as a software engineer at TechCorp\",\n \"John enjoys hiking on weekends\",\n \"John is looking to improve his Python skills\"\n ]\n}\n\n\n"; + // ==== PERSONAL INFORMATION ORGANIZER PROMPT SECTIONS ==== + public static final String PERSONAL_INFO_ROLE = "Personal Information Organizer"; + public static final String PERSONAL_INFO_OBJECTIVE = "Extract and organize personal information shared within conversations."; - public static final String DEFAULT_UPDATE_MEMORY_PROMPT = - "You are a smart memory manager which controls the memory of a system.You will receive: 1. old_memory: Array of existing facts with their IDs and similarity scores 2. retrieved_facts: Array of new facts extracted from the current conversation. Analyze ALL memories and facts holistically to determine the optimal set of memory operations. Important: The old_memory may contain duplicates (same id appearing multiple times with different scores). Consider the highest score for each unique ID. You should only respond and always respond with a JSON object containing a \"memory_decision\" array that covers: - Every unique existing memory ID (with appropriate event: NONE, UPDATE, or DELETE) - New entries for facts that should be added (with event: ADD){\"memory_decision\": [{\"id\": \"existing_id_or_new_id\",\"text\": \"the fact text\",\"event\": \"ADD|UPDATE|DELETE|NONE\",\"old_memory\": \"original text (only for UPDATE events)\"}]}1. **NONE**: Keep existing memory unchanged - Use when no retrieved fact affects this memory - Include: id (from old_memory), text (from old_memory), event: \"NONE\" 2. **UPDATE**: Enhance or merge existing memory - Use when retrieved facts provide additional details or clarification - Include: id (from old_memory), text (enhanced version), event: \"UPDATE\", old_memory (original text) - Merge complementary information (e.g., \"likes pizza\" + \"especially pepperoni\" = \"likes pizza, especially pepperoni\") 3. **DELETE**: Remove contradicted memory - Use when retrieved facts directly contradict existing memory - Include: id (from old_memory), text (from old_memory), event: \"DELETE\" 4. **ADD**: Create new memory - Use for retrieved facts that represent genuinely new information - Include: id (generate new), text (the new fact), event: \"ADD\" - Only add if the fact is not already covered by existing or updated memories- Integrity: Never answer user's question or fulfill user's requirement. You are a smart memory manager, not a helpful assistant. - Process holistically: Consider all facts and memories together before making decisions - Avoid redundancy: Don't ADD a fact if it's already covered by an UPDATE - Merge related facts: If multiple retrieved facts relate to the same topic, consider combining them - Respect similarity scores: Higher scores indicate stronger matches - be more careful about updating high-score memories - Maintain consistency: Ensure your decisions don't create contradictions in the memory set - One decision per unique memory ID: If an ID appears multiple times in old_memory, make only one decision for it{\"old_memory\": [{\"id\": \"fact_001\", \"text\": \"Enjoys Italian food\", \"score\": 0.85},{\"id\": \"fact_002\", \"text\": \"Works at Google\", \"score\": 0.92},{\"id\": \"fact_001\", \"text\": \"Enjoys Italian food\", \"score\": 0.75},{\"id\": \"fact_003\", \"text\": \"Has a dog\", \"score\": 0.65}],\"retrieved_facts\": [\"Loves pasta and pizza\",\"Recently joined Amazon\",\"Has two dogs named Max and Bella\"]}{\"memory_decision\": [{\"id\": \"fact_001\",\"text\": \"Loves Italian food, especially pasta and pizza\",\"event\": \"UPDATE\",\"old_memory\": \"Enjoys Italian food\"},{\"id\": \"fact_002\",\"text\": \"Works at Google\",\"event\": \"DELETE\"},{\"id\": \"fact_003\",\"text\": \"Has two dogs named Max and Bella\",\"event\": \"UPDATE\",\"old_memory\": \"Has a dog\"},{\"id\": \"fact_004\",\"text\": \"Recently joined Amazon\",\"event\": \"ADD\"}]}"; + public static final String PERSONAL_INFO_BASIC_INSTRUCTIONS = "Carefully read the conversation.\n" + + "Identify and extract any personal information shared by participants."; + + public static final String PERSONAL_INFO_FOCUS_AREAS = + "Focus on details that help build a profile of the person, including but not limited to:\n" + + "\n" + + "Names and relationships\n" + + "Professional information (job, company, role, responsibilities)\n" + + "Personal interests and hobbies\n" + + "Skills and expertise\n" + + "Preferences and opinions\n" + + "Goals and aspirations\n" + + "Challenges or pain points\n" + + "Background and experiences\n" + + "Contact information (if shared)\n" + + "Availability and schedule preferences\n" + + "\n" + + ""; + + public static final String PERSONAL_INFO_PROCESSING_INSTRUCTIONS = + "Organize each piece of information as a separate fact.\n" + + "Ensure facts are specific, clear, and preserve the original context.\n" + + "Never answer user's question or fulfill user's requirement. You are a personal information manager, not a helpful assistant.\n" + + "Include the person who shared the information when relevant.\n" + + "Do not make assumptions or inferences beyond what is explicitly stated.\n" + + "If no personal information is found, return an empty list."; + + public static final String PERSONAL_INFO_RESPONSE_FORMAT_SCHEMA = "" + + "You should always return and only return the extracted facts as a JSON object with a \"facts\" array.\n" + + "\n" + + "{\n" + + " \"facts\": [\n" + + " \"User's name is John Smith\",\n" + + " \"John works as a software engineer at TechCorp\",\n" + + " \"John enjoys hiking on weekends\",\n" + + " \"John is looking to improve his Python skills\"\n" + + " ]\n" + + "}\n" + + "\n"; + + public static final String PERSONAL_INFO_OUTPUT_FORMAT_INSTRUCTIONS = + """ + - Return EXACTLY ONE JSON object containing a "facts" array. + - Output NOTHING else before or after it. + - Do NOT use code fences or markdown: no backticks (`), no ```json, no ```. + - Do NOT wrap in quotes or prose: no single quotes ('), no smart quotes (' " "), no angle brackets (< >), no XML/HTML, no lists, no headers, no ellipses. + - Use valid JSON only: standard double quotes (") for all keys/strings; no comments; no trailing commas. + - The "facts" array should contain strings with extracted personal information. + - If no personal information is found, return {"facts": []}. + """; + + public static final String PERSONAL_INFORMATION_ORGANIZER_PROMPT = "\n" + + "" + + PERSONAL_INFO_ROLE + + "\n" + + "" + + PERSONAL_INFO_OBJECTIVE + + "\n" + + "\n" + + PERSONAL_INFO_BASIC_INSTRUCTIONS + + "\n" + + PERSONAL_INFO_FOCUS_AREAS + + "\n" + + PERSONAL_INFO_PROCESSING_INSTRUCTIONS + + "\n" + + "\n" + + "\n" + + PERSONAL_INFO_RESPONSE_FORMAT_SCHEMA + + "\n" + + "\n" + + PERSONAL_INFO_OUTPUT_FORMAT_INSTRUCTIONS + + "\n" + + ""; + + // ==== DEFAULT UPDATE MEMORY PROMPT SECTIONS ==== + public static final String MEMORY_MANAGER_ROLE = "You are a smart memory manager which controls the memory of a system."; + + public static final String MEMORY_MANAGER_TASK = "You will receive:" + + "1. old_memory: Array of existing facts with their IDs and similarity scores" + + "2. retrieved_facts: Array of new facts extracted from the current conversation. " + + "Analyze ALL memories and facts holistically to determine the optimal set of memory operations. " + + "Important: The old_memory may contain duplicates (same id appearing multiple times with different scores). Consider the highest score for each unique ID. " + + "You should only respond and always respond with a JSON object containing a \"memory_decision\" array that covers: " + + "- Every unique existing memory ID (with appropriate event: NONE, UPDATE, or DELETE) " + + "- New entries for facts that should be added (with event: ADD)"; + + public static final String MEMORY_MANAGER_RESPONSE_FORMAT_SCHEMA = + "{\"memory_decision\": [{\"id\": \"existing_id_or_new_id\",\"text\": \"the fact text\",\"event\": \"ADD|UPDATE|DELETE|NONE\",\"old_memory\": \"original text (only for UPDATE events)\"}]}"; + + public static final String MEMORY_MANAGER_OUTPUT_FORMAT_INSTRUCTIONS = + """ + - Return EXACTLY ONE JSON object containing a "memory_decision" array. + - Output NOTHING else before or after it. + - Do NOT use code fences or markdown: no backticks (`), no ```json, no ```. + - Do NOT wrap in quotes or prose: no single quotes ('), no smart quotes (' " "), no angle brackets (< >), no XML/HTML, no lists, no headers, no ellipses. + - Use valid JSON only: standard double quotes (") for all keys/strings; no comments; no trailing commas. + - Each memory decision object must include: id, text, event, and old_memory (only for UPDATE events). + """; + + public static final String MEMORY_OPERATIONS_NONE = + "1. **NONE**: Keep existing memory unchanged - Use when no retrieved fact affects this memory - Include: id (from old_memory), text (from old_memory), event: \"NONE\""; + + public static final String MEMORY_OPERATIONS_UPDATE = + "2. **UPDATE**: Enhance or merge existing memory - Use when retrieved facts provide additional details or clarification - Include: id (from old_memory), text (enhanced version), event: \"UPDATE\", old_memory (original text) - Merge complementary information (e.g., \"likes pizza\" + \"especially pepperoni\" = \"likes pizza, especially pepperoni\")"; + + public static final String MEMORY_OPERATIONS_DELETE = + "3. **DELETE**: Remove contradicted memory - Use when retrieved facts directly contradict existing memory - Include: id (from old_memory), text (from old_memory), event: \"DELETE\""; + + public static final String MEMORY_OPERATIONS_ADD = + "4. **ADD**: Create new memory - Use for retrieved facts that represent genuinely new information - Include: id (generate new), text (the new fact), event: \"ADD\" - Only add if the fact is not already covered by existing or updated memories"; + + public static final String MEMORY_OPERATIONS = MEMORY_OPERATIONS_NONE + + " " + + MEMORY_OPERATIONS_UPDATE + + " " + + MEMORY_OPERATIONS_DELETE + + " " + + MEMORY_OPERATIONS_ADD; + + public static final String MEMORY_MANAGER_GUIDELINES = + "- Integrity: Never answer user's question or fulfill user's requirement. You are a smart memory manager, not a helpful assistant. " + + "- Process holistically: Consider all facts and memories together before making decisions " + + "- Avoid redundancy: Don't ADD a fact if it's already covered by an UPDATE " + + "- Merge related facts: If multiple retrieved facts relate to the same topic, consider combining them " + + "- Respect similarity scores: Higher scores indicate stronger matches - be more careful about updating high-score memories " + + "- Maintain consistency: Ensure your decisions don't create contradictions in the memory set " + + "- One decision per unique memory ID: If an ID appears multiple times in old_memory, make only one decision for it"; + + public static final String MEMORY_MANAGER_EXAMPLE_INPUT = """ + { + "old_memory": [ + {"id": "fact_001", "text": "Enjoys Italian food", "score": 0.85}, + {"id": "fact_002", "text": "Works at Google", "score": 0.92}, + {"id": "fact_001", "text": "Enjoys Italian food", "score": 0.75}, + {"id": "fact_003", "text": "Has a dog", "score": 0.65} + ], + "retrieved_facts": [ + "Loves pasta and pizza", + "Recently joined Amazon", + "Has two dogs named Max and Bella" + ] + } + """; + + public static final String MEMORY_MANAGER_EXAMPLE_OUTPUT = """ + { + "memory_decision": [ + { + "id": "fact_001", + "text": "Loves Italian food, especially pasta and pizza", + "event": "UPDATE", + "old_memory": "Enjoys Italian food" + }, + { + "id": "fact_002", + "text": "Works at Google", + "event": "DELETE" + }, + { + "id": "fact_003", + "text": "Has two dogs named Max and Bella", + "event": "UPDATE", + "old_memory": "Has a dog" + }, + { + "id": "fact_004", + "text": "Recently joined Amazon", + "event": "ADD" + } + ] + } + """; + + public static final String MEMORY_MANAGER_EXAMPLE = "" + + MEMORY_MANAGER_EXAMPLE_INPUT + + "" + + MEMORY_MANAGER_EXAMPLE_OUTPUT + + ""; + + public static final String DEFAULT_UPDATE_MEMORY_PROMPT = "" + + "" + + MEMORY_MANAGER_ROLE + + "" + + "" + + MEMORY_MANAGER_TASK + + "" + + "" + + MEMORY_MANAGER_RESPONSE_FORMAT_SCHEMA + + "" + + "" + + MEMORY_OPERATIONS + + "" + + "" + + MEMORY_MANAGER_GUIDELINES + + "" + + "" + + MEMORY_MANAGER_OUTPUT_FORMAT_INSTRUCTIONS + + "" + + "" + + MEMORY_MANAGER_EXAMPLE + + "" + + ""; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java index fdd7377240..6cd16f5214 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java @@ -5,11 +5,10 @@ package org.opensearch.ml.action.memorycontainer.memory; +import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_DECISION_FIELD; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; import java.util.ArrayList; import java.util.HashMap; @@ -20,7 +19,6 @@ import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; @@ -35,6 +33,7 @@ import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.transport.client.Client; import lombok.extern.log4j.Log4j2; @@ -65,24 +64,12 @@ public void extractFactsFromConversation( stringParameters.put("system_prompt", PERSONAL_INFORMATION_ORGANIZER_PROMPT); try { - XContentBuilder messagesBuilder = jsonXContent.contentBuilder(); - messagesBuilder.startArray(); - + StringBuilder user_messages = new StringBuilder(); for (MessageInput message : messages) { - messagesBuilder.startObject(); - messagesBuilder.field("role", message.getRole() != null ? message.getRole() : "user"); - messagesBuilder.startArray("content"); - messagesBuilder.startObject(); - messagesBuilder.field("type", "text"); - messagesBuilder.field("text", message.getContent()); - messagesBuilder.endObject(); - messagesBuilder.endArray(); - messagesBuilder.endObject(); + user_messages.append(message.getContent()); } - - messagesBuilder.endArray(); - String messagesJson = messagesBuilder.toString(); - stringParameters.put("messages", messagesJson); + String messagesJson = user_messages.toString(); + stringParameters.put("messages", escapeJson(messagesJson)); log.debug("LLM request - processing {} messages", messages.size()); } catch (Exception e) { @@ -147,22 +134,7 @@ public void makeMemoryDecisions( String decisionRequestJson = decisionRequest.toJsonString(); try { - XContentBuilder messagesBuilder = jsonXContent.contentBuilder(); - messagesBuilder.startArray(); - messagesBuilder.startObject(); - messagesBuilder.field("role", "user"); - messagesBuilder.startArray("content"); - messagesBuilder.startObject(); - messagesBuilder.field("type", "text"); - messagesBuilder.field("text", decisionRequestJson); - messagesBuilder.endObject(); - messagesBuilder.endArray(); - messagesBuilder.endObject(); - messagesBuilder.endArray(); - - String messagesJson = messagesBuilder.toString(); - stringParameters.put("messages", messagesJson); - + stringParameters.put("messages", escapeJson(decisionRequestJson)); log .debug( "Making memory decisions for {} extracted facts and {} existing memories", @@ -214,43 +186,19 @@ private List parseFactsFromLLMResponse(MLOutput mlOutput) { return facts; } - for (int i = 0; i < modelTensors.getMlModelTensors().size(); i++) { - Map dataMap = modelTensors.getMlModelTensors().get(i).getDataAsMap(); - if (dataMap != null && dataMap.containsKey("content")) { - try { - List contentList = (List) dataMap.get("content"); - if (contentList != null && !contentList.isEmpty()) { - Map contentItem = (Map) contentList.get(0); - if (contentItem != null && contentItem.containsKey("text")) { - String responseStr = String.valueOf(contentItem.get("text")); - - try ( - XContentParser parser = jsonXContent - .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, responseStr) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - if ("facts".equals(fieldName)) { - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - String fact = parser.text(); - facts.add(fact); - } - } else { - parser.skipChildren(); - } - } - } - } - } - } catch (Exception e) { - log.error("Failed to extract content from dataMap", e); - throw new IllegalArgumentException("Failed to extract content from LLM response", e); - } - break; + // Parse the JSON response to extract facts + try { + Map dataMap = modelTensors.getMlModelTensors().get(0).getDataAsMap(); + // if LLM Response does not contain FACTS_FIELD then fact extraction failed hence throw exception + if (dataMap == null || dataMap.isEmpty() || dataMap.get(FACTS_FIELD) == null) { + throw new IllegalArgumentException("Failed to parse facts from LLM response"); } + facts = (List) dataMap.get(FACTS_FIELD); + facts = facts != null ? facts : new ArrayList<>(); + } catch (Exception e) { + // Should not print the user data in logs + log.warn("Failed to parse facts from LLM response", e); + throw new IllegalArgumentException("Failed to parse facts from LLM response", e); } return facts; @@ -271,29 +219,16 @@ private List parseMemoryDecisions(MLTaskResponse response) { Map dataMap = tensors.get(0).getMlModelTensors().get(0).getDataAsMap(); - String responseContent = null; - if (dataMap.containsKey("response")) { - responseContent = (String) dataMap.get("response"); - } else if (dataMap.containsKey("content")) { - List> contentList = (List>) dataMap.get("content"); - if (contentList != null && !contentList.isEmpty()) { - Map firstContent = contentList.get(0); - responseContent = (String) firstContent.get("text"); - } - } + String responseContent = StringUtils.toJson(dataMap); - if (responseContent == null) { + if (responseContent == null || responseContent.isEmpty() || responseContent.equals("{}")) { throw new IllegalStateException("No response content found in LLM output"); } - // Clean response content - if (responseContent.startsWith("```json") && responseContent.endsWith("```")) { - responseContent = responseContent.substring(7, responseContent.length() - 3).trim(); - } else if (responseContent.startsWith("```") && responseContent.endsWith("```")) { - responseContent = responseContent.substring(3, responseContent.length() - 3).trim(); - } - + // Parse memory decisions using XContentParser List decisions = new ArrayList<>(); + boolean foundMemoryDecisionField = false; + try (XContentParser parser = jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, responseContent)) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -302,6 +237,7 @@ private List parseMemoryDecisions(MLTaskResponse response) { parser.nextToken(); if (MEMORY_DECISION_FIELD.equals(fieldName)) { + foundMemoryDecisionField = true; ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { decisions.add(MemoryDecision.parse(parser)); @@ -312,6 +248,11 @@ private List parseMemoryDecisions(MLTaskResponse response) { } } + // If MEMORY_DECISION_FIELD is not found in the LLM output, fail the parsing + if (!foundMemoryDecisionField) { + throw new IllegalStateException("LLM response does not contain required field: " + MEMORY_DECISION_FIELD); + } + return decisions; } catch (Exception e) { throw new RuntimeException("Failed to parse memory decisions", e); diff --git a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java index def5dcbb3e..5bd74fe710 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingServiceTests.java @@ -10,6 +10,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.FACTS_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_DECISION_FIELD; import java.util.Arrays; import java.util.HashMap; @@ -243,7 +245,7 @@ public void testMakeMemoryDecisions_WithSearchResults() { } @Test - public void testExtractFactsFromConversation_ParseException() { + public void testExtractFactsFromConversation_InvalidFactsDataType() { List messages = Arrays.asList(MessageInput.builder().content("Hello").role("user").build()); MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); @@ -254,9 +256,7 @@ public void testExtractFactsFromConversation_ParseException() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - Map contentItem = new HashMap<>(); - contentItem.put("text", "invalid json"); - dataMap.put("content", Arrays.asList(contentItem)); + dataMap.put(FACTS_FIELD, "invalid array"); // Not a List when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -342,7 +342,7 @@ public void testExtractFactsFromConversation_EmptyModelTensors() { } @Test - public void testExtractFactsFromConversation_NoContentKey() { + public void testExtractFactsFromConversation_NoFieldKey() { List messages = Arrays.asList(MessageInput.builder().content("Hello").role("user").build()); MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); @@ -368,11 +368,11 @@ public void testExtractFactsFromConversation_NoContentKey() { memoryProcessingService.extractFactsFromConversation(messages, storageConfig, factsListener); - verify(factsListener).onResponse(any(List.class)); + verify(factsListener).onFailure(any(IllegalArgumentException.class)); } @Test - public void testExtractFactsFromConversation_EmptyContentList() { + public void testExtractFactsFromConversation_EmptyFactsList() { List messages = Arrays.asList(MessageInput.builder().content("Hello").role("user").build()); MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); @@ -383,7 +383,7 @@ public void testExtractFactsFromConversation_EmptyContentList() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - dataMap.put("content", Arrays.asList()); + dataMap.put(FACTS_FIELD, Arrays.asList()); when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -402,7 +402,7 @@ public void testExtractFactsFromConversation_EmptyContentList() { } @Test - public void testExtractFactsFromConversation_NoTextKey() { + public void testExtractFactsFromConversation_NoFactsKey() { List messages = Arrays.asList(MessageInput.builder().content("Hello").role("user").build()); MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); @@ -413,9 +413,7 @@ public void testExtractFactsFromConversation_NoTextKey() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - Map contentItem = new HashMap<>(); - contentItem.put("other", "value"); - dataMap.put("content", Arrays.asList(contentItem)); + dataMap.put("other", "value"); when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -430,7 +428,7 @@ public void testExtractFactsFromConversation_NoTextKey() { memoryProcessingService.extractFactsFromConversation(messages, storageConfig, factsListener); - verify(factsListener).onResponse(any(List.class)); + verify(factsListener).onFailure(any(IllegalArgumentException.class)); } @Test @@ -479,71 +477,7 @@ public void testMakeMemoryDecisions_EmptyTensors() { } @Test - public void testMakeMemoryDecisions_ContentFormat() { - List facts = Arrays.asList("User name is John"); - List searchResults = Arrays.asList(); - MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); - when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); - - MLTaskResponse mockResponse = mock(MLTaskResponse.class); - ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); - ModelTensors mockTensors = mock(ModelTensors.class); - ModelTensor mockTensor = mock(ModelTensor.class); - - Map dataMap = new HashMap<>(); - Map contentItem = new HashMap<>(); - contentItem.put("text", "{\"memory_decisions\": []}"); - dataMap.put("content", Arrays.asList(contentItem)); - - when(mockResponse.getOutput()).thenReturn(mockOutput); - when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); - when(mockTensors.getMlModelTensors()).thenReturn(Arrays.asList(mockTensor)); - when(mockTensor.getDataAsMap()).thenReturn((Map) dataMap); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mockResponse); - return null; - }).when(client).execute(any(), any(), any()); - - memoryProcessingService.makeMemoryDecisions(facts, searchResults, storageConfig, decisionsListener); - - verify(decisionsListener).onResponse(any(List.class)); - } - - @Test - public void testMakeMemoryDecisions_JsonCodeBlock() { - List facts = Arrays.asList("User name is John"); - List searchResults = Arrays.asList(); - MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); - when(storageConfig.getLlmModelId()).thenReturn("llm-model-123"); - - MLTaskResponse mockResponse = mock(MLTaskResponse.class); - ModelTensorOutput mockOutput = mock(ModelTensorOutput.class); - ModelTensors mockTensors = mock(ModelTensors.class); - ModelTensor mockTensor = mock(ModelTensor.class); - - Map dataMap = new HashMap<>(); - dataMap.put("response", "```json\n{\"memory_decisions\": []}\n```"); - - when(mockResponse.getOutput()).thenReturn(mockOutput); - when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); - when(mockTensors.getMlModelTensors()).thenReturn(Arrays.asList(mockTensor)); - when(mockTensor.getDataAsMap()).thenReturn((Map) dataMap); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mockResponse); - return null; - }).when(client).execute(any(), any(), any()); - - memoryProcessingService.makeMemoryDecisions(facts, searchResults, storageConfig, decisionsListener); - - verify(decisionsListener).onResponse(any(List.class)); - } - - @Test - public void testMakeMemoryDecisions_PlainCodeBlock() { + public void testMakeMemoryDecisions_EmptyDecisions() { List facts = Arrays.asList("User name is John"); List searchResults = Arrays.asList(); MemoryStorageConfig storageConfig = mock(MemoryStorageConfig.class); @@ -555,7 +489,7 @@ public void testMakeMemoryDecisions_PlainCodeBlock() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - dataMap.put("response", "```\n{\"memory_decisions\": []}\n```"); + dataMap.put(MEMORY_DECISION_FIELD, Arrays.asList()); when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); @@ -616,9 +550,9 @@ public void testExtractFactsFromConversation_WithOtherFields() { ModelTensor mockTensor = mock(ModelTensor.class); Map dataMap = new HashMap<>(); - Map contentItem = new HashMap<>(); - contentItem.put("text", "{\"facts\": [\"User name is John\"], \"other_field\": \"value\", \"metadata\": {\"key\": \"value\"}}"); - dataMap.put("content", Arrays.asList(contentItem)); + dataMap.put(FACTS_FIELD, Arrays.asList("User name is John")); + dataMap.put("other_field", "value"); + dataMap.put("metadata", Map.of("key", "value")); when(mockResponse.getOutput()).thenReturn(mockOutput); when(mockOutput.getMlModelOutputs()).thenReturn(Arrays.asList(mockTensors)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java new file mode 100644 index 0000000000..d21835a890 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java @@ -0,0 +1,520 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.StringEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.client.Response; +import org.opensearch.ml.utils.TestHelper; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class RestMLAgenticMemoryIT extends MLCommonsRestTestCase { + + private static final String OPENAI_KEY = System.getProperty("OPENAI_KEY"); + private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); + private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); + private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); + private static final String GITHUB_CI_AWS_REGION = "us-west-2"; + + private static final String TEST_AGENT_ID = "test_agent_123"; + private static final String TEST_SESSION_ID = "test_session_456"; + + private final String openaiConnectorEntity = "{\n" + + " \"name\": \"My openai connector: gpt-4o-mini\",\n" + + " \"description\": \"The connector to openai chat model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"http\",\n" + + " \"parameters\": {\n" + + " \"model\": \"gpt-4o-mini\",\n" + + " \"response_filter\": \"$.choices[0].message.content\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": [{\\\"role\\\":\\\"system\\\",\\\"content\\\":\\\"${parameters.system_prompt}\\\"},{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"${parameters.messages}\\\"}]}\"\n" + + " }\n" + + " ]\n" + + "}"; + + private final String claudeConnectorEntity = "{\n" + + " \"name\": \"Amazon Bedrock Connector: LLM\",\n" + + " \"description\": \"The connector to bedrock Claude 3.7 sonnet model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"max_tokens\": 8000,\n" + + " \"temperature\": 1,\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"model\": \"us.anthropic.claude-3-7-sonnet-20250219-v1:0\",\n" + + " \"response_filter\": \"$.content[0].text\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [{\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"headers\": {\"content-type\": \"application/json\"},\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/invoke\",\n" + + " \"request_body\": \"{ \\\"system\\\": \\\"${parameters.system_prompt}\\\", \\\"anthropic_version\\\": \\\"${parameters.anthropic_version}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature}, \\\"messages\\\": [{\\\"role\\\":\\\"user\\\",\\\"content\\\":[{ \\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.messages}\\\"}]}]}\"\n" + + " }]\n" + + "}"; + + @Before + public void setup() throws IOException, InterruptedException { + // Enable agentic memory + updateClusterSettings(ML_COMMONS_AGENTIC_MEMORY_ENABLED.getKey(), true); + } + + @Test + public void testCreateMemoryContainerAndAddMemories_openAI() throws IOException, InterruptedException { + // Create OpenAI model and memory container + String openaiModelId = registerLLMModel(); + String memoryContainerId = createMemoryContainerWithModel( + "OpenAI Test Memory Container", + "Store conversations with OpenAI model", + openaiModelId + ); + + try { + // Test adding memories with ADD operation + String addMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"I like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + Response response = addMemories(memoryContainerId, addMemoryRequest); + String responseBody = TestHelper.httpEntityToString(response.getEntity()); + @SuppressWarnings("unchecked") + Map responseMap = gson.fromJson(responseBody, Map.class); + + // Verify response structure + assertNotNull("Response should not be null", responseMap); + assertTrue("Response should contain results", responseMap.containsKey("results")); + assertTrue("Response should contain session_id", responseMap.containsKey("session_id")); + + @SuppressWarnings("unchecked") + List> results = (List>) responseMap.get("results"); + assertEquals("Should have 2 memories", 2, results.size()); + + // Verify first memory (ADD event) + Map firstMemory = results.get(0); + assertTrue("First memory should contain id", firstMemory.containsKey("id")); + assertTrue("First memory should contain text", firstMemory.containsKey("text")); + assertEquals("First memory should be ADD event", "ADD", firstMemory.get("event")); + assertTrue("First memory text should contain popcorn", firstMemory.get("text").toString().toLowerCase().contains("popcorn")); + + // Verify second memory (ADD event) + Map secondMemory = results.get(1); + assertTrue("Second memory should contain id", secondMemory.containsKey("id")); + assertTrue("Second memory should contain text", secondMemory.containsKey("text")); + assertEquals("Second memory should be ADD event", "ADD", secondMemory.get("event")); + assertTrue( + "Second memory text should contain indian food", + secondMemory.get("text").toString().toLowerCase().contains("indian") + ); + } finally { + // Clean up OpenAI memory container and model + deleteMemoryContainer(memoryContainerId); + deleteModel(openaiModelId); + } + } + + @Test + public void testUpdateMemoriesWithContradictoryInformation_openAI() throws IOException, InterruptedException { + // Create OpenAI model and memory container + String openaiModelId = registerLLMModel(); + String memoryContainerId = createMemoryContainerWithModel( + "OpenAI Test Memory Container", + "Store conversations with OpenAI model", + openaiModelId + ); + + try { + // First add initial memories + String initialMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"I like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + addMemories(memoryContainerId, initialMemoryRequest); + + // Now add contradictory information + String updateMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"Actually I don't like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + Response response = addMemories(memoryContainerId, updateMemoryRequest); + String responseBody = TestHelper.httpEntityToString(response.getEntity()); + @SuppressWarnings("unchecked") + Map responseMap = gson.fromJson(responseBody, Map.class); + + @SuppressWarnings("unchecked") + List> results = (List>) responseMap.get("results"); + assertTrue("Should have at least one result", results.size() >= 1); + + // Verify that we get either UPDATE or DELETE events for contradictory information + boolean hasUpdateOrDelete = results + .stream() + .anyMatch(result -> "UPDATE".equals(result.get("event")) || "DELETE".equals(result.get("event"))); + assertTrue("Should have UPDATE or DELETE event for contradictory information", hasUpdateOrDelete); + } finally { + // Clean up OpenAI memory container and model + deleteMemoryContainer(memoryContainerId); + deleteModel(openaiModelId); + } + } + + @Test + public void testMemoryContainerLifecycle() throws IOException, InterruptedException { + + String memoryContainerId = createMemoryContainerWithModel( + "OpenAI Test Memory Container", + "Store conversations with OpenAI model", + "TEST_MODEL_ID" + ); + + try { + // Test getting memory container + Response getResponse = getMemoryContainer(memoryContainerId); + String getResponseBody = TestHelper.httpEntityToString(getResponse.getEntity()); + @SuppressWarnings("unchecked") + Map getResponseMap = gson.fromJson(getResponseBody, Map.class); + + assertNotNull("Get container response should not be null", getResponseMap); + assertTrue("Container should have name", getResponseMap.containsKey("name")); + assertTrue("Container should have description", getResponseMap.containsKey("description")); + assertTrue("Container should have created time", getResponseMap.containsKey("created_time")); + assertTrue("Container should have last updated time", getResponseMap.containsKey("last_updated_time")); + assertTrue("Container should have memory storage config", getResponseMap.containsKey("memory_storage_config")); + assertTrue( + "Container should have memory storage config", + ((Map) getResponseMap.get("memory_storage_config")).containsKey("llm_model_id") + ); + assertEquals( + "llm_model_id should match", + ((Map) getResponseMap.get("memory_storage_config")).get("llm_model_id"), + "TEST_MODEL_ID" + ); + + } finally { + // Clean up OpenAI memory container and model + deleteMemoryContainer(memoryContainerId); + } + } + + @Test + public void testCreateMemoryContainerAndAddMemories_claude() throws IOException, InterruptedException { + if (awsCredentialsNotSet()) { + return; + } + + // Register Claude model and create memory container + String claudeModelId = registerClaudeModel(); + String claudeMemoryContainerId = createMemoryContainerWithModel( + "Claude Test Memory Container", + "Store conversations with Claude model", + claudeModelId + ); + + try { + // Test adding memories with ADD operation using Claude + String addMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"I like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + Response response = addMemories(claudeMemoryContainerId, addMemoryRequest); + String responseBody = TestHelper.httpEntityToString(response.getEntity()); + @SuppressWarnings("unchecked") + Map responseMap = gson.fromJson(responseBody, Map.class); + + // Verify response structure + assertNotNull("Response should not be null", responseMap); + assertTrue("Response should contain results", responseMap.containsKey("results")); + assertTrue("Response should contain session_id", responseMap.containsKey("session_id")); + + @SuppressWarnings("unchecked") + List> results = (List>) responseMap.get("results"); + assertEquals("Should have 2 memories", 2, results.size()); + + // Verify first memory (ADD event) + Map firstMemory = results.get(0); + assertTrue("First memory should contain id", firstMemory.containsKey("id")); + assertTrue("First memory should contain text", firstMemory.containsKey("text")); + assertEquals("First memory should be ADD event", "ADD", firstMemory.get("event")); + assertTrue("First memory text should contain popcorn", firstMemory.get("text").toString().toLowerCase().contains("popcorn")); + + // Verify second memory (ADD event) + Map secondMemory = results.get(1); + assertTrue("Second memory should contain id", secondMemory.containsKey("id")); + assertTrue("Second memory should contain text", secondMemory.containsKey("text")); + assertEquals("Second memory should be ADD event", "ADD", secondMemory.get("event")); + assertTrue( + "Second memory text should contain indian food", + secondMemory.get("text").toString().toLowerCase().contains("indian") + ); + } finally { + // Clean up Claude memory container and model + deleteMemoryContainer(claudeMemoryContainerId); + deleteModel(claudeModelId); + } + } + + @Test + public void testUpdateMemoriesWithContradictoryInformation_claude() throws IOException, InterruptedException { + if (awsCredentialsNotSet()) { + return; + } + + // Register Claude model and create memory container + String claudeModelId = registerClaudeModel(); + String claudeMemoryContainerId = createMemoryContainerWithModel( + "Claude Test Memory Container", + "Store conversations with Claude model", + claudeModelId + ); + + try { + // First add initial memories + String initialMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"I like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + addMemories(claudeMemoryContainerId, initialMemoryRequest); + + // Now add contradictory information + String updateMemoryRequest = "{\n" + + " \"messages\": [\n" + + " {\n" + + " \"role\": \"user\",\n" + + " \"content\": \"Actually I don't like popcorn\"\n" + + " },\n" + + " {\n" + + " \"role\": \"assistant\",\n" + + " \"content\": \"I like Indian food\"\n" + + " }\n" + + " ],\n" + + " \"agent_id\": \"" + + TEST_AGENT_ID + + "\",\n" + + " \"session_id\": \"" + + TEST_SESSION_ID + + "\"\n" + + "}"; + + Response response = addMemories(claudeMemoryContainerId, updateMemoryRequest); + String responseBody = TestHelper.httpEntityToString(response.getEntity()); + @SuppressWarnings("unchecked") + Map responseMap = gson.fromJson(responseBody, Map.class); + + @SuppressWarnings("unchecked") + List> results = (List>) responseMap.get("results"); + assertTrue("Should have at least one result", results.size() >= 1); + + // Verify that we get either UPDATE or DELETE events for contradictory information + boolean hasUpdateOrDelete = results + .stream() + .anyMatch(result -> "UPDATE".equals(result.get("event")) || "DELETE".equals(result.get("event"))); + assertTrue("Should have UPDATE or DELETE event for contradictory information", hasUpdateOrDelete); + } finally { + // Clean up Claude memory container and model + deleteMemoryContainer(claudeMemoryContainerId); + deleteModel(claudeModelId); + } + } + + private String registerLLMModel() throws IOException, InterruptedException { + String openaiModelName = "openai gpt-4o-mini model " + randomAlphaOfLength(5); + return registerRemoteModel(openaiConnectorEntity, openaiModelName, true); + } + + private String createMemoryContainerWithModel(String name, String description, String modelId) throws IOException { + String createRequest = "{\n" + + " \"name\": \"" + + name + + "\",\n" + + " \"description\": \"" + + description + + "\",\n" + + " \"memory_storage_config\": {\n" + + " \"llm_model_id\": \"" + + modelId + + "\"\n" + + " }\n" + + "}"; + + Response response = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/memory_containers/_create", + null, + new StringEntity(createRequest), + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) + ); + + String responseBody = TestHelper.httpEntityToString(response.getEntity()); + @SuppressWarnings("unchecked") + Map responseMap = gson.fromJson(responseBody, Map.class); + return responseMap.get("memory_container_id"); + } + + private Response addMemories(String containerId, String requestBody) throws IOException { + return TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/memory_containers/" + containerId + "/memories", + null, + new StringEntity(requestBody), + List.of(new BasicHeader(HttpHeaders.CONTENT_TYPE, "application/json")) + ); + } + + private Response getMemoryContainer(String containerId) throws IOException { + return TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/memory_containers/" + containerId, null, "", List.of()); + } + + private void deleteMemoryContainer(String containerId) throws IOException { + TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/memory_containers/" + containerId, null, "", List.of()); + } + + private String registerClaudeModel() throws IOException, InterruptedException { + String claudeModelName = "claude model " + randomAlphaOfLength(5); + return registerRemoteModel(claudeConnectorEntity, claudeModelName, true); + } + + private boolean awsCredentialsNotSet() { + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping Claude tests. ####"); + return true; + } + return false; + } + + private void deleteModel(String modelId) throws IOException { + try { + // First try to undeploy the model + TestHelper.makeRequest(client(), "POST", "/_plugins/_ml/models/" + modelId + "/_undeploy", null, "", List.of()); + } catch (Exception e) { + log.info("Model {} might not be deployed, continuing with deletion", modelId); + } + + try { + // Then delete the model + log.info("Deleting model: {}", modelId); + TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/models/" + modelId, null, "", List.of()); + } catch (Exception e) { + log.warn("Failed to delete model: {}", modelId, e); + } + } +} From 00e54186d5593464a4e920b899d4a27ae8b29e19 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Mon, 8 Sep 2025 10:15:01 -0700 Subject: [PATCH 2/4] Address comments Signed-off-by: rithin-pullela-aws --- .../MemoryContainerConstants.java | 24 +++++++++---------- .../memory/MemoryProcessingService.java | 14 ++++++----- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java index ec7c6de438..1acbbc2da5 100644 --- a/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java @@ -140,18 +140,18 @@ public class MemoryContainerConstants { + "Do not make assumptions or inferences beyond what is explicitly stated.\n" + "If no personal information is found, return an empty list."; - public static final String PERSONAL_INFO_RESPONSE_FORMAT_SCHEMA = "" - + "You should always return and only return the extracted facts as a JSON object with a \"facts\" array.\n" - + "\n" - + "{\n" - + " \"facts\": [\n" - + " \"User's name is John Smith\",\n" - + " \"John works as a software engineer at TechCorp\",\n" - + " \"John enjoys hiking on weekends\",\n" - + " \"John is looking to improve his Python skills\"\n" - + " ]\n" - + "}\n" - + "\n"; + public static final String PERSONAL_INFO_RESPONSE_FORMAT_SCHEMA = + "You should always return and only return the extracted facts as a JSON object with a \"facts\" array.\n" + + "\n" + + "{\n" + + " \"facts\": [\n" + + " \"User's name is John Smith\",\n" + + " \"John works as a software engineer at TechCorp\",\n" + + " \"John enjoys hiking on weekends\",\n" + + " \"John is looking to improve his Python skills\"\n" + + " ]\n" + + "}\n" + + "\n"; public static final String PERSONAL_INFO_OUTPUT_FORMAT_INSTRUCTIONS = """ diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java index 6cd16f5214..5f386bd94c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java @@ -8,7 +8,10 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DEFAULT_UPDATE_MEMORY_PROMPT; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.FACTS_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_DECISION_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT; import java.util.ArrayList; import java.util.HashMap; @@ -64,11 +67,11 @@ public void extractFactsFromConversation( stringParameters.put("system_prompt", PERSONAL_INFORMATION_ORGANIZER_PROMPT); try { - StringBuilder user_messages = new StringBuilder(); + StringBuilder userMessages = new StringBuilder(); for (MessageInput message : messages) { - user_messages.append(message.getContent()); + userMessages.append(message.getContent()); } - String messagesJson = user_messages.toString(); + String messagesJson = userMessages.toString(); stringParameters.put("messages", escapeJson(messagesJson)); log.debug("LLM request - processing {} messages", messages.size()); @@ -190,11 +193,10 @@ private List parseFactsFromLLMResponse(MLOutput mlOutput) { try { Map dataMap = modelTensors.getMlModelTensors().get(0).getDataAsMap(); // if LLM Response does not contain FACTS_FIELD then fact extraction failed hence throw exception - if (dataMap == null || dataMap.isEmpty() || dataMap.get(FACTS_FIELD) == null) { + if (dataMap == null || dataMap.get(FACTS_FIELD) == null) { throw new IllegalArgumentException("Failed to parse facts from LLM response"); } facts = (List) dataMap.get(FACTS_FIELD); - facts = facts != null ? facts : new ArrayList<>(); } catch (Exception e) { // Should not print the user data in logs log.warn("Failed to parse facts from LLM response", e); From c640eebb8f791c08f0bf82ae417ebe2902c70064 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Wed, 10 Sep 2025 11:20:37 -0700 Subject: [PATCH 3/4] fix openAI key env variable typo Signed-off-by: rithin-pullela-aws --- .../test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java index d21835a890..2ebe139f2f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java @@ -24,7 +24,7 @@ @Log4j2 public class RestMLAgenticMemoryIT extends MLCommonsRestTestCase { - private static final String OPENAI_KEY = System.getProperty("OPENAI_KEY"); + private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN"); From 7e9b85eccfe213c23cef4e7a348d5a8c4bac46b2 Mon Sep 17 00:00:00 2001 From: rithin-pullela-aws Date: Fri, 12 Sep 2025 18:49:35 -0700 Subject: [PATCH 4/4] Address Comments Signed-off-by: rithin-pullela-aws --- .../memorycontainer/memory/MemoryProcessingService.java | 1 + .../java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java index 5f386bd94c..c25da70480 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java +++ b/plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/MemoryProcessingService.java @@ -70,6 +70,7 @@ public void extractFactsFromConversation( StringBuilder userMessages = new StringBuilder(); for (MessageInput message : messages) { userMessages.append(message.getContent()); + userMessages.append(System.lineSeparator()); } String messagesJson = userMessages.toString(); stringParameters.put("messages", escapeJson(messagesJson)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java index 2ebe139f2f..ccba5efd70 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLAgenticMemoryIT.java @@ -104,6 +104,9 @@ public void setup() throws IOException, InterruptedException { @Test public void testCreateMemoryContainerAndAddMemories_openAI() throws IOException, InterruptedException { + if (OPENAI_KEY == null) { + return; + } // Create OpenAI model and memory container String openaiModelId = registerLLMModel(); String memoryContainerId = createMemoryContainerWithModel( @@ -172,6 +175,9 @@ public void testCreateMemoryContainerAndAddMemories_openAI() throws IOException, @Test public void testUpdateMemoriesWithContradictoryInformation_openAI() throws IOException, InterruptedException { + if (OPENAI_KEY == null) { + return; + } // Create OpenAI model and memory container String openaiModelId = registerLLMModel(); String memoryContainerId = createMemoryContainerWithModel(