diff --git a/docker/requirements.txt b/docker/requirements.txt
index d20c0b36e..4846f1832 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -157,4 +157,4 @@ volcengine-python-sdk==4.0.6
 watchfiles==1.1.0
 websockets==15.0.1
 xlrd==2.0.2
-xlsxwriter==3.2.5
\ No newline at end of file
+xlsxwriter==3.2.5
diff --git a/docs/openapi.json b/docs/openapi.json
index 5a3471ac0..ee2ff1368 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -884,7 +884,7 @@
           "type": "string",
           "title": "Session Id",
           "description": "Session ID for the MOS. This is used to distinguish between different dialogue",
-          "default": "0ce84b9c-0615-4b9d-83dd-fba50537d5d3"
+          "default": "41bb5e18-252d-4948-918c-07d82aa47086"
         },
         "chat_model": {
           "$ref": "#/components/schemas/LLMConfigFactory",
@@ -939,6 +939,12 @@
           "description": "Enable parametric memory for the MemChat",
           "default": false
         },
+        "enable_preference_memory": {
+          "type": "boolean",
+          "title": "Enable Preference Memory",
+          "description": "Enable preference memory for the MemChat",
+          "default": false
+        },
         "enable_mem_scheduler": {
           "type": "boolean",
           "title": "Enable Mem Scheduler",
diff --git a/evaluation/.env-example b/evaluation/.env-example
index 4b2b9311f..bda935442 100644
--- a/evaluation/.env-example
+++ b/evaluation/.env-example
@@ -22,9 +22,13 @@ SUPERMEMORY_API_KEY="sm_xxx"
 MEMOBASE_API_KEY="xxx"
 MEMOBASE_PROJECT_URL="http://***.***.***.***:8019"
 
-# eval settings
-PRE_SPLIT_CHUNK=false
-
+# pref
+PRE_SPLIT_CHUNK=false # pre-split chunks on the client side, for PersonaMem and PrefEval
+# 1. text_mem + pref_mem + instruction_completion: set INSTRUCT_COMPLETE=true, ABLATION_PREF=false
+# 2. text_mem + pref_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=false
+# 3. text_mem: set INSTRUCT_COMPLETE=false, ABLATION_PREF=true
+INSTRUCT_COMPLETE=true # whether to use the instruction-completion prompt format
+ABLATION_PREF=false # remove preference memory and keep text memory only
 
 # Configuration Only For Scheduler
 # RabbitMQ Configuration
@@ -45,4 +49,4 @@ MEMSCHEDULER_GRAPHDBAUTH_URI=bolt://localhost:7687
 MEMSCHEDULER_GRAPHDBAUTH_USER=neo4j
 MEMSCHEDULER_GRAPHDBAUTH_PASSWORD=***
 MEMSCHEDULER_GRAPHDBAUTH_DB_NAME=neo4j
-MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true
\ No newline at end of file
+MEMSCHEDULER_GRAPHDBAUTH_AUTO_CREATE=true
diff --git a/evaluation/scripts/PrefEval/pref_eval.py b/evaluation/scripts/PrefEval/pref_eval.py
index 10cf41bf3..f1966b847 100644
--- a/evaluation/scripts/PrefEval/pref_eval.py
+++ b/evaluation/scripts/PrefEval/pref_eval.py
@@ -1,14 +1,18 @@
+import argparse
 import asyncio
 import json
+import os
 import re
-import argparse
-from typing import Dict, Any
+
 from collections import Counter
-from tqdm.asyncio import tqdm
-import os
+from typing import Any
+
 import pandas as pd
+
 from dotenv import load_dotenv
 from openai import OpenAI
+from tqdm.asyncio import tqdm
+
 
 load_dotenv()
@@ -30,7 +34,7 @@ async def call_gpt4o_mini_async(client: OpenAI, prompt: str) -> str:
         )
         return response.choices[0].message.content
     except Exception as e:
-        print(f"API call failed: {str(e)}")
+        print(f"API call failed: {e!s}")
         return ""
@@ -41,7 +45,7 @@ def parse_xml_response(response: str, tag: str) -> str:
 
 async def evaluate_violate_preference_async(
     client: OpenAI, preference: str, question: str, response: str
-) -> Dict[str, str]:
+) -> dict[str, str]:
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's response violates the user's stated preference.
 
 Evaluate the response based on these criteria:
@@ -73,8 +77,8 @@ async def evaluate_acknowledge_preference_async(
     client: OpenAI, question: str, response: str
-) -> Dict[str, str]:
-    prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query. 
+) -> dict[str, str]:
+    prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant acknowledges any user preference in answering the user's query.
 
 Evaluate the response based on these stringent criteria:
@@ -82,7 +86,7 @@ async def evaluate_acknowledge_preference_async(
 
 Answer "Yes" if:
 
-1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'. 
+1. The response explicitly mentions or refers to a user preference in answering the user's question. Examples include: "Based on your previous preference for xxxx"; "Based on your previous preference, I would suggest you xxxx"; "Since you mentioned you prefer/dislike xxxx"; "Since you are a xxxx"; "I will recommend the following given your preference for xxx", etc. You should extract the "xxxx" in your answer. If it only says "Based on our previous conversation, I recommend..." and does not explicitly mention any preference, you should answer 'No'.
 
 2. The response assumes the user preference in answering the user's question implicitly. For example, when the user asks 'Can you recommend me cars to drive?', if the response is 'Based on your preference, I will recommend non-EV cars, ...', then this indicates the assistant assumes that the user's preference is a dislike of EV cars, and you should answer 'Yes'.
 
 Answer "No" if the response does not mention or refer to any user preference explicitly or implicitly. If you cannot extract the sentence stating what the preference is, answer 'No'.
@@ -103,7 +107,7 @@
 
 async def evaluate_hallucinate_preference_async(
     client: OpenAI, preference: str, restatement: str
-) -> Dict[str, str]:
+) -> dict[str, str]:
     if not restatement.strip():
         return {"explanation": "No restatement provided by assistant", "answer": "No"}
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant's restatement of the user's stated preference is the same preference. Evaluate the response based on these stringent criteria to answer if the assistant has hallucinated the preference or not:
@@ -120,10 +124,10 @@ async def evaluate_hallucinate_preference_async(
 2. The assistant's restatement is a minor paraphrase that fully preserves the meaning and intent of the original preference.
 3. The restatement is just empty, no hallucination.
 
-Here is the information: 
-Original user preference: "{preference}" 
-Assistant's restatement: "{restatement}" 
-Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text: 
+Here is the information:
+Original user preference: "{preference}"
+Assistant's restatement: "{restatement}"
+Examine the original preference and the assistant's restatement meticulously and answer. Please answer in this exact XML format without any other additional text:
 <explanation>[1 short sentence explanation]</explanation>
 <answer>[Yes/No]</answer>"""
@@ -136,7 +140,7 @@
 
 async def evaluate_helpful_response_async(
     client: OpenAI, question: str, response: str
-) -> Dict[str, str]:
+) -> dict[str, str]:
     prompt = f"""You will analyze a conversation between a user and an assistant, focusing on whether the assistant provides any substantive response to the user's query.
 
 Evaluate the response based on these stringent criteria:
@@ -174,7 +178,7 @@ async def evaluate_helpful_response_async(
     }
 
 
-def classify_error_type(evaluation_results: Dict[str, Any]) -> str:
+def classify_error_type(evaluation_results: dict[str, Any]) -> str:
     violate = evaluation_results["violate_preference"]["answer"]
     acknowledge = evaluation_results["acknowledge_preference"]["answer"]
     hallucinate = evaluation_results["hallucinate_preference"]["answer"]
@@ -192,7 +196,7 @@ def classify_error_type(evaluation_results: dict[str, Any]) -> str:
     return "Personalized Response"
 
 
-async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> Dict[str, Any]:
+async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore) -> dict[str, Any]:
     async with semaphore:
         data = json.loads(line.strip())
         preference = data["preference"]
@@ -223,7 +227,7 @@ async def process_line(line: str, client: OpenAI, semaphore: asyncio.Semaphore)
     return result
 
 
-def log_summary(error_counter: Counter, total_samples: int) -> Dict[str, Dict[str, float]]:
+def log_summary(error_counter: Counter, total_samples: int) -> dict[str, dict[str, float]]:
     summary_data = {}
 
     print("\n--- Error Type Summary ---")
@@ -247,7 +251,7 @@
 
 def generate_excel_summary(
-    summary_results: Dict[str, Dict[str, float]],
+    summary_results: dict[str, dict[str, float]],
     avg_search_time: float,
     avg_context_tokens: float,
     avg_add_time: float,
@@ -317,7 +321,7 @@ async def main(concurrency_limit: int, input_file: str, output_file: str, output
     client = OpenAI(api_key=API_KEY, base_url=API_URL)
 
     try:
-        with open(input_file, "r", encoding="utf-8") as f:
+        with open(input_file, encoding="utf-8") as f:
             lines = f.readlines()
     except FileNotFoundError:
         print(f"Error: Input file not found at '{input_file}'")
diff --git a/evaluation/scripts/PrefEval/pref_mem0.py b/evaluation/scripts/PrefEval/pref_mem0.py
index 416d8045f..4bbdb0fd8 100644
--- a/evaluation/scripts/PrefEval/pref_mem0.py
+++ b/evaluation/scripts/PrefEval/pref_mem0.py
@@ -4,12 +4,14 @@ import os
 import sys
 import time
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-from irrelevant_conv import irre_10, irre_300
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -199,7 +201,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
diff --git a/evaluation/scripts/PrefEval/pref_memobase.py b/evaluation/scripts/PrefEval/pref_memobase.py
index 34d3ea86f..4f6174d3d 100644
--- a/evaluation/scripts/PrefEval/pref_memobase.py
+++ b/evaluation/scripts/PrefEval/pref_memobase.py
@@ -4,12 +4,14 @@ import os
 import sys
 import time
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-import time
-from irrelevant_conv import irre_10, irre_300
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -49,7 +51,7 @@ def add_memory_for_line(
     if conversation:
         messages = []
-        for chunk_start in range(0, len(conversation)):
+        for chunk_start in range(len(conversation)):
             chunk = conversation[chunk_start : chunk_start + 1]
             timestamp_add = str(int(time.time() * 100))
             time.sleep(0.001)  # Ensure unique timestamp
@@ -210,7 +212,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
diff --git a/evaluation/scripts/PrefEval/pref_memos.py b/evaluation/scripts/PrefEval/pref_memos.py
index 5ee064b1f..753a77d99 100644
--- a/evaluation/scripts/PrefEval/pref_memos.py
+++ b/evaluation/scripts/PrefEval/pref_memos.py
@@ -4,12 +4,14 @@ import os
 import sys
 import time
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-from irrelevant_conv import irre_10, irre_300
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -18,6 +20,8 @@
 sys.path.insert(0, ROOT_DIR)
 sys.path.insert(0, EVAL_SCRIPTS_DIR)
 
+
+
 load_dotenv()
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 BASE_URL = os.getenv("OPENAI_BASE_URL")
@@ -68,6 +72,8 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di
     """
     Processes a single line of data, searching memory based on the question.
     """
+    from utils.pref_mem_utils import create_mem_string
+
     i, line = line_data
     try:
         original_data = json.loads(line)
@@ -88,9 +94,7 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di
         start_time_search = time.monotonic()
         relevant_memories = mem_client.search(query=question, user_id=user_id, top_k=top_k_value)
         search_memories_duration = time.monotonic() - start_time_search
-        memories_str = "\n".join(
-            f"- {entry.get('memory', '')}" for entry in relevant_memories["text_mem"][0]["memories"]
-        )
+        memories_str = create_mem_string(relevant_memories)
 
         memory_tokens_used = len(tokenizer.encode(memories_str))
@@ -111,10 +115,13 @@ def search_memory_for_line(line_data: tuple, mem_client, top_k_value: int) -> di
         return None
 
 
-def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict:
+def generate_response_for_line(line_data: tuple, openai_client: OpenAI, lib: str) -> dict:
     """
     Generates a response for a single line of data using pre-fetched memories.
     """
+    from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string
+    from utils.prompts import PREFEVAL_ANSWER_PROMPT
+
     i, line = line_data
     try:
         original_data = json.loads(line)
@@ -139,7 +146,10 @@ def generate_response_for_line(line_data: tuple, openai_client: OpenAI) -> dict:
         )
         return original_data
 
-    system_prompt = f"You are a helpful AI. Answer the question based on the query and the following memories:\nUser Memories:\n{memories_str}"
+    memories_str = remove_pref_mem_from_mem_string(memories_str, frame=lib)
+
+    template = add_pref_instruction(PREFEVAL_ANSWER_PROMPT, frame=lib)
+    system_prompt = template.format(context=memories_str)
     messages = [
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": question},
@@ -201,7 +211,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
@@ -277,7 +287,7 @@ def main():
         concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor,
     ):
         futures = [
-            executor.submit(generate_response_for_line, (i, line), openai_client)
+            executor.submit(generate_response_for_line, (i, line), openai_client, args.lib)
             for i, line in enumerate(lines)
         ]
diff --git a/evaluation/scripts/PrefEval/pref_memu.py b/evaluation/scripts/PrefEval/pref_memu.py
index 719f2b488..2b9f769a4 100644
--- a/evaluation/scripts/PrefEval/pref_memu.py
+++ b/evaluation/scripts/PrefEval/pref_memu.py
@@ -4,12 +4,16 @@ import os
 import sys
 import time
+
+from datetime import datetime
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-from datetime import datetime
-from irrelevant_conv import irre_10, irre_300
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -205,7 +209,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
diff --git a/evaluation/scripts/PrefEval/pref_supermemory.py b/evaluation/scripts/PrefEval/pref_supermemory.py
index 85e84b6c9..88a64038b 100644
--- a/evaluation/scripts/PrefEval/pref_supermemory.py
+++ b/evaluation/scripts/PrefEval/pref_supermemory.py
@@ -4,12 +4,14 @@ import os
 import sys
 import time
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-from datetime import datetime
-from irrelevant_conv import irre_10, irre_300
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -201,7 +203,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
diff --git a/evaluation/scripts/PrefEval/pref_zep.py b/evaluation/scripts/PrefEval/pref_zep.py
index 699660787..91aef1492 100644
--- a/evaluation/scripts/PrefEval/pref_zep.py
+++ b/evaluation/scripts/PrefEval/pref_zep.py
@@ -4,12 +4,16 @@ import os
 import sys
 import time
+
+from datetime import datetime
+
 import tiktoken
+
 from dotenv import load_dotenv
+from irrelevant_conv import irre_10, irre_300
 from openai import OpenAI
 from tqdm import tqdm
-from datetime import datetime
-from irrelevant_conv import irre_10, irre_300
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -211,7 +215,7 @@ def main():
     args = parser.parse_args()
 
     try:
-        with open(args.input, "r", encoding="utf-8") as infile:
+        with open(args.input, encoding="utf-8") as infile:
             lines = infile.readlines()
     except FileNotFoundError:
         print(f"Error: Input file '{args.input}' not found")
diff --git a/evaluation/scripts/PrefEval/prefeval_preprocess.py b/evaluation/scripts/PrefEval/prefeval_preprocess.py
index 004d5e505..9ace9dec9 100644
--- a/evaluation/scripts/PrefEval/prefeval_preprocess.py
+++ b/evaluation/scripts/PrefEval/prefeval_preprocess.py
@@ -1,7 +1,8 @@
-from datasets import load_dataset
 import json
 import os
 
+from datasets import load_dataset
+
 
 def convert_dataset_to_jsonl(dataset_name, output_dir="./scripts/PrefEval"):
     if not os.path.exists(output_dir):
@@ -64,7 +65,7 @@ def process_jsonl_file(input_filepath, output_filepath):
     line_count = 0
     print(f"Start processing file: {input_filepath}")
     with (
-        open(input_filepath, "r", encoding="utf-8") as infile,
+        open(input_filepath, encoding="utf-8") as infile,
         open(output_filepath, "w", encoding="utf-8") as outfile,
     ):
         for line in infile:
diff --git a/evaluation/scripts/locomo/locomo_ingestion.py b/evaluation/scripts/locomo/locomo_ingestion.py
index edb451dc0..fe7aa86f7 100644
--- a/evaluation/scripts/locomo/locomo_ingestion.py
+++ b/evaluation/scripts/locomo/locomo_ingestion.py
@@ -1,12 +1,16 @@
-import os
-import sys
 import argparse
 import concurrent.futures
+import os
+import sys
 import time
+
 from datetime import datetime, timezone
+
 import pandas as pd
+
 from dotenv import load_dotenv
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 )
@@ -88,8 +92,8 @@ def process_user(conv_idx, frame, locomo_df, version):
     client = None
 
     if frame == "mem0" or frame == "mem0_graph":
-        from utils.client import Mem0Client
         from prompts import custom_instructions
+        from utils.client import Mem0Client
 
         client = Mem0Client(enable_graph="graph" in frame)
         client.client.update_project(custom_instructions=custom_instructions)
diff --git a/evaluation/scripts/locomo/locomo_responses.py b/evaluation/scripts/locomo/locomo_responses.py
index 4e3b966a3..2ae4dcb6e 100644
--- a/evaluation/scripts/locomo/locomo_responses.py
+++ b/evaluation/scripts/locomo/locomo_responses.py
@@ -2,6 +2,7 @@
 import asyncio
 import json
 import os
+import sys
 
 from time import time
 
@@ -13,6 +14,15 @@
 from tqdm import tqdm
 
 
+ROOT_DIR = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+)
+EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")
+
+sys.path.insert(0, ROOT_DIR)
+sys.path.insert(0, EVAL_SCRIPTS_DIR)
+
+
 async def locomo_response(frame, llm_client, context: str, question: str) -> str:
     if frame == "zep":
         prompt = ANSWER_PROMPT_ZEP.format(
@@ -25,7 +35,10 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str
             question=question,
         )
     else:
-        prompt = ANSWER_PROMPT_MEMOS.format(
+        from utils.pref_mem_utils import add_pref_instruction
+
+        template = add_pref_instruction(ANSWER_PROMPT_MEMOS, frame=frame)
+        prompt = template.format(
             context=context,
             question=question,
         )
@@ -42,12 +55,17 @@ async def locomo_response(frame, llm_client, context: str, question: str) -> str
 
 
 async def process_qa(frame, qa, search_result, oai_client):
+    from utils.pref_mem_utils import remove_pref_mem_from_mem_string
+
     start = time()
     query = qa.get("question")
     gold_answer = qa.get("answer")
     qa_category = qa.get("category")
-    answer = await locomo_response(frame, oai_client, search_result.get("context"), query)
+    context = search_result.get("context")
+
+    context = remove_pref_mem_from_mem_string(context, frame)
+    answer = await locomo_response(frame, oai_client, context, query)
 
     response_duration_ms = (time() - start) * 1000
diff --git a/evaluation/scripts/locomo/locomo_search.py b/evaluation/scripts/locomo/locomo_search.py
index 452fb4762..19efb5b92 100644
--- a/evaluation/scripts/locomo/locomo_search.py
+++ b/evaluation/scripts/locomo/locomo_search.py
@@ -1,14 +1,18 @@
-import os
-import sys
 import argparse
 import json
+import os
+import sys
+
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from time import time
+
 import pandas as pd
+
 from dotenv import load_dotenv
 from tqdm import tqdm
+
 
 ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 )
@@ -96,16 +100,14 @@ def memos_api_search(
     client, query, speaker_a_user_id, speaker_b_user_id, top_k, speaker_a, speaker_b
 ):
     from prompts import TEMPLATE_MEMOS
+    from utils.pref_mem_utils import create_mem_string
 
     start = time()
     search_a_results = client.search(query=query, user_id=speaker_a_user_id, top_k=top_k)
     search_b_results = client.search(query=query, user_id=speaker_b_user_id, top_k=top_k)
-    speaker_a_context = "\n".join(
-        [i["memory"] for i in search_a_results["text_mem"][0]["memories"]]
-    )
-    speaker_b_context = "\n".join(
-        [i["memory"] for i in search_b_results["text_mem"][0]["memories"]]
-    )
+
+    speaker_a_context = create_mem_string(search_a_results)
+    speaker_b_context = create_mem_string(search_b_results)
 
     context = TEMPLATE_MEMOS.format(
         speaker_1=speaker_a,
diff --git a/evaluation/scripts/locomo/prompts.py b/evaluation/scripts/locomo/prompts.py
index 2827716a0..caf462f6a 100644
--- a/evaluation/scripts/locomo/prompts.py
+++ b/evaluation/scripts/locomo/prompts.py
@@ -1,3 +1,14 @@
+import os
+
+
+PREF_INSTRUCTIONS = """
+    # Note:
+    Plaintext memories are summaries of facts, while preference memories are summaries of user preferences.
+    Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts.
+    When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory.
+"""
+
+
 ANSWER_PROMPT_MEM0 = """
     You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
@@ -49,12 +60,12 @@
     5. Always convert relative time references to specific dates, months, or years.
     6. Be as specific as possible when talking about people, places, and events
     7. Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message.
-    
+
       Clarification: When interpreting memories, use the timestamp to determine when the described event happened, not when someone talked about the event.
-    
+
       Example:
-    
+
       Memory: (2023-03-15T16:33:00Z) I went to the vet yesterday.
       Question: What day did I go to the vet?
       Correct Answer: March 15, 2023
@@ -103,7 +114,7 @@
     5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge).
     6. Double-check that your answer directly addresses the question asked and adheres to all instructions.
     7. Ensure your final answer is specific and avoids vague time references.
-    
+    {pref_instructions}
     {context}
 
     Question: {question}
@@ -111,6 +122,11 @@
     Answer:
     """
 
+if os.getenv("INSTRUCT_COMPLETE") == "true":
+    ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", PREF_INSTRUCTIONS)
+else:
+    ANSWER_PROMPT_MEMOS = ANSWER_PROMPT_MEMOS.replace("{pref_instructions}", "")
+
 custom_instructions = """
 Generate personal memories that follow these guidelines:
diff --git a/evaluation/scripts/longmemeval/lme_eval.py b/evaluation/scripts/longmemeval/lme_eval.py
index 45c038a2b..73117b925 100644
--- a/evaluation/scripts/longmemeval/lme_eval.py
+++ b/evaluation/scripts/longmemeval/lme_eval.py
@@ -26,6 +26,7 @@
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from utils.prompts import LME_JUDGE_MODEL_TEMPLATE
 
+
 encoding = tiktoken.get_encoding("cl100k_base")
 logging.basicConfig(level=logging.CRITICAL)
 transformers.logging.set_verbosity_error()
diff --git a/evaluation/scripts/longmemeval/lme_ingestion.py b/evaluation/scripts/longmemeval/lme_ingestion.py
index a1849757d..325178292 100644
--- a/evaluation/scripts/longmemeval/lme_ingestion.py
+++ b/evaluation/scripts/longmemeval/lme_ingestion.py
@@ -1,11 +1,15 @@
 import argparse
 import os
 import sys
+
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime, timezone
+
 import pandas as pd
+
 from tqdm import tqdm
+
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -126,7 +130,7 @@ def main(frame, version, num_workers=2):
     success_records = []
     record_file = f"results/lme/{frame}-{version}/success_records.txt"
     if os.path.exists(record_file):
-        with open(record_file, "r") as f:
+        with open(record_file) as f:
             for i in f.readlines():
                 success_records.append(i.strip())
diff --git a/evaluation/scripts/longmemeval/lme_responses.py b/evaluation/scripts/longmemeval/lme_responses.py
index 3df3e2da4..22f17c304 100644
--- a/evaluation/scripts/longmemeval/lme_responses.py
+++ b/evaluation/scripts/longmemeval/lme_responses.py
@@ -12,16 +12,17 @@
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
+from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string
 from utils.prompts import LME_ANSWER_PROMPT
 
 
-def lme_response(llm_client, context, question, question_date):
-    prompt = LME_ANSWER_PROMPT.format(
+def lme_response(llm_client, context, question, question_date, frame):
+    template = add_pref_instruction(LME_ANSWER_PROMPT, frame=frame)
+    prompt = template.format(
         question=question,
         question_date=question_date,
         context=context,
     )
-
     response = llm_client.chat.completions.create(
         model=os.getenv("CHAT_MODEL"),
         messages=[
@@ -34,13 +35,14 @@ def lme_response(llm_client, context, question, question_date):
     return result
 
 
-def process_qa(user_id, search_result, llm_client):
+def process_qa(user_id, search_result, llm_client, frame):
     start = time()
     search_result = search_result[0]
     question = search_result.get("question")
     question_date = search_result.get("date")
     context = search_result.get("search_context", "")
-    anwer = lme_response(llm_client, context, question, question_date)
+    context = remove_pref_mem_from_mem_string(context, frame=frame)
+    anwer = lme_response(llm_client, context, question, question_date, frame)
 
     response_duration_ms = (time() - start) * 1000
@@ -95,7 +97,7 @@ def main(frame, version, num_workers=4):
         future_to_user_id = {}
 
         for user_id, search_results in lme_search_results.items():
-            future = executor.submit(process_qa, user_id, search_results, oai_client)
+            future = executor.submit(process_qa, user_id, search_results, oai_client, frame)
             future_to_user_id[future] = user_id
 
         for future in tqdm(
diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py
index 67d2f1b04..d21795eef 100644
--- a/evaluation/scripts/longmemeval/lme_search.py
+++ b/evaluation/scripts/longmemeval/lme_search.py
@@ -3,6 +3,7 @@
 import os
 import sys
 
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -10,13 +11,13 @@
 from time import time
 
 import pandas as pd
+
 from tqdm import tqdm
 
+from utils.pref_mem_utils import create_mem_string
 from utils.prompts import (
     MEM0_CONTEXT_TEMPLATE,
     MEM0_GRAPH_CONTEXT_TEMPLATE,
-    MEMOBASE_CONTEXT_TEMPLATE,
     MEMOS_CONTEXT_TEMPLATE,
-    ZEP_CONTEXT_TEMPLATE,
 )
 
 
@@ -44,7 +45,7 @@ def mem0_search(client, query, user_id, top_k):
 def memos_search(client, query, user_id, top_k):
     start = time()
     results = client.search(query=query, user_id=user_id, top_k=top_k)
-    context = "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]])
+    context = create_mem_string(results)
     context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=context)
     duration_ms = (time() - start) * 1000
     return context, duration_ms
diff --git a/evaluation/scripts/personamem/pm_ingestion.py b/evaluation/scripts/personamem/pm_ingestion.py
index 8de23937c..5204b5c2a 100644
--- a/evaluation/scripts/personamem/pm_ingestion.py
+++ b/evaluation/scripts/personamem/pm_ingestion.py
@@ -3,10 +3,13 @@
 import json
 import os
 import sys
+import time
+
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from datetime import datetime
+
 from tqdm import tqdm
-import time
+
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -34,7 +37,7 @@ def ingest_session(session, user_id, session_id, frame, client):
         client.add(messages=session, user_id=user_id, conv_id=session_id)
         print(f"[{frame}] ✅ Session [{session_id}]: Ingested {len(session)} messages")
     elif frame == "memobase":
-        for idx, msg in enumerate(session):
+        for _idx, msg in enumerate(session):
             if msg["role"] != "system":
                 messages.append(
                     {
@@ -67,7 +70,7 @@ def build_jsonl_index(jsonl_path):
     Assumes each line is a JSON object with a single key-value pair.
""" index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -79,14 +82,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -98,7 +101,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -118,7 +121,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/personamem/pm_metric.py b/evaluation/scripts/personamem/pm_metric.py index 653c5fc10..e88c538d4 100644 --- a/evaluation/scripts/personamem/pm_metric.py +++ b/evaluation/scripts/personamem/pm_metric.py @@ -44,7 +44,7 @@ def save_to_excel(results, output_path): category_row[f"response_{metric}"] = value # Add search duration metrics (if exists) - if "search_duration" in scores and scores["search_duration"]: + if scores.get("search_duration"): for metric, value in scores["search_duration"].items(): category_row[f"search_{metric}"] = value @@ -80,7 +80,7 @@ def calculate_scores(data, grade_path, output_path): print(f"📋 Processing response data for {len(data)} users...") # First pass: determine number of runs and initialize run accuracy arrays - for user_id, user_data in data.items(): + for _user_id, user_data in data.items(): # Skip incomplete data (users with only topic field) if len(user_data) <= 2 and "topic" in user_data: continue @@ -371,7 +371,7 @@ def print_summary(results): print(f"📂 Loading response data from: {responses_path}") try: - with open(responses_path, "r", encoding="utf-8") as file: + with open(responses_path, encoding="utf-8") as file: data = json.load(file) # Calculate metrics diff --git a/evaluation/scripts/personamem/pm_responses.py b/evaluation/scripts/personamem/pm_responses.py index 8bfeaf5f6..5b54f9bb8 100644 --- a/evaluation/scripts/personamem/pm_responses.py +++ b/evaluation/scripts/personamem/pm_responses.py @@ -10,11 +10,13 @@ from openai import OpenAI from tqdm import tqdm -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from utils.prompts import PM_ANSWER_PROMPT +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import re +from utils.pref_mem_utils import add_pref_instruction, remove_pref_mem_from_mem_string +from utils.prompts import PM_ANSWER_PROMPT + def extract_choice_answer(predicted_answer, correct_answer): def _extract_only_options(text): @@ -47,8 +49,9 @@ def _extract_only_options(text): return False, predicted_answer -def pm_response(llm_client, context, question, options): - prompt = PM_ANSWER_PROMPT.format( +def pm_response(llm_client, context, question, options, frame): + template = 
add_pref_instruction(PM_ANSWER_PROMPT, frame=frame) + prompt = template.format( question=question, context=context, options=options, @@ -65,17 +68,19 @@ def pm_response(llm_client, context, question, options): return result -def process_qa(user_id, search_result, num_runs, llm_client): +def process_qa(user_id, search_result, num_runs, llm_client, frame): search_result = search_result[0] question = search_result.get("question") context = search_result.get("search_context", "") options = search_result.get("all_options", []) + context = remove_pref_mem_from_mem_string(context, frame=frame) + run_results = [] for idx in range(num_runs): start = time() - answer = pm_response(llm_client, context, question, options) + answer = pm_response(llm_client, context, question, options, frame) is_correct, answer = extract_choice_answer(answer, search_result.get("golden_answer", "")) response_duration_ms = (time() - start) * 1000 @@ -149,7 +154,9 @@ def main(frame, version, num_runs=3, num_workers=4): future_to_user_id = {} for user_id, search_results in pm_search_results.items(): - future = executor.submit(process_qa, user_id, search_results, num_runs, oai_client) + future = executor.submit( + process_qa, user_id, search_results, num_runs, oai_client, frame + ) future_to_user_id[future] = user_id for future in tqdm( diff --git a/evaluation/scripts/personamem/pm_search.py b/evaluation/scripts/personamem/pm_search.py index 2e1a268fc..243c64589 100644 --- a/evaluation/scripts/personamem/pm_search.py +++ b/evaluation/scripts/personamem/pm_search.py @@ -1,16 +1,20 @@ import argparse +import csv import json import os import sys + from collections import defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from time import time + from tqdm import tqdm -import csv + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from utils.pref_mem_utils import create_mem_string from utils.prompts import ( MEM0_CONTEXT_TEMPLATE, MEM0_GRAPH_CONTEXT_TEMPLATE, @@ -79,9 +83,7 @@ def memobase_search(client, query, user_id, top_k): def memos_search(client, user_id, query, top_k): start = time() results = client.search(query=query, user_id=user_id, top_k=top_k) - search_memories = "\n".join( - item["memory"] for cube in results["text_mem"] for item in cube["memories"] - ) + search_memories = create_mem_string(results) context = MEMOS_CONTEXT_TEMPLATE.format(user_id=user_id, memories=search_memories) duration_ms = (time() - start) * 1000 @@ -109,7 +111,7 @@ def build_jsonl_index(jsonl_path): Assumes each line is a JSON object with a single key-value pair. 
""" index = {} - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: while True: offset = f.tell() line = f.readline() @@ -121,14 +123,14 @@ def build_jsonl_index(jsonl_path): def load_context_by_id(jsonl_path, offset): - with open(jsonl_path, "r", encoding="utf-8") as f: + with open(jsonl_path, encoding="utf-8") as f: f.seek(offset) item = json.loads(f.readline()) return next(iter(item.values())) def load_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for _, row in enumerate(reader, start=1): row_data = {} @@ -140,7 +142,7 @@ def load_rows(csv_path): def load_rows_with_context(csv_path, jsonl_path): jsonl_index = build_jsonl_index(jsonl_path) - with open(csv_path, mode="r", newline="", encoding="utf-8") as csvfile: + with open(csv_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) prev_sid = None prev_context = None @@ -163,7 +165,7 @@ def load_rows_with_context(csv_path, jsonl_path): def count_csv_rows(csv_path): - with open(csv_path, mode="r", newline="", encoding="utf-8") as f: + with open(csv_path, newline="", encoding="utf-8") as f: return sum(1 for _ in f) - 1 diff --git a/evaluation/scripts/run_pm_eval.sh b/evaluation/scripts/run_pm_eval.sh index f83893fed..a46440bfc 100755 --- a/evaluation/scripts/run_pm_eval.sh +++ b/evaluation/scripts/run_pm_eval.sh @@ -1,7 +1,7 @@ #!/bin/bash # Common parameters for all scripts -LIB="memu" +LIB="memos-api" VERSION="072202" WORKERS=10 TOPK=20 @@ -62,4 +62,4 @@ else fi fi -echo "All scripts completed successfully!" \ No newline at end of file +echo "All scripts completed successfully!" diff --git a/evaluation/scripts/run_prefeval_eval.sh b/evaluation/scripts/run_prefeval_eval.sh old mode 100644 new mode 100755 index 001f8299d..a79cefcc2 --- a/evaluation/scripts/run_prefeval_eval.sh +++ b/evaluation/scripts/run_prefeval_eval.sh @@ -11,13 +11,13 @@ WORKERS=10 # Parameters for pref_memos.py TOP_K=6 ADD_TURN=0 # Options: 0, 10, or 300 -LIB="memos-api" +LIB="memos-api" VERSION="1022-0" # --- File Paths --- # You may need to adjust these paths based on your project structure. # Step 1 (preprocess) outputs this file: -PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" +PREPROCESSED_FILE="data/prefeval/pref_processed.jsonl" # Create a directory name based on the *specific* LIB (e.g., "memos") OUTPUT_DIR="results/prefeval/${LIB}_${VERSION}" @@ -54,7 +54,7 @@ export HF_ENDPOINT="https://hf-mirror.com" echo "--- Starting PrefEval Pipeline ---" echo "Configuration: WORKERS=$WORKERS, TOP_K=$TOP_K, ADD_TURN=$ADD_TURN, LIB=$LIB, VERSION=$VERSION, HF_ENDPOINT=$HF_ENDPOINT" echo "Results will be saved to: $OUTPUT_DIR" -echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" +echo "Using script: $LIB_SCRIPT (mapped from LIB=$LIB)" echo "" # --- Step 1: Preprocess the data --- @@ -134,7 +134,7 @@ echo "Running pref_eval.py..." python scripts/PrefEval/pref_eval.py \ --input $RESPONSE_FILE \ --concurrency-limit $WORKERS - + if [ $? -ne 0 ]; then echo "Error: Evaluation script failed." 
     exit 1
diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py
index 2efb0493d..ffc9dda12 100644
--- a/evaluation/scripts/utils/client.py
+++ b/evaluation/scripts/utils/client.py
@@ -3,11 +3,15 @@
 import sys
 import time
 import uuid
+
 from contextlib import suppress
 from datetime import datetime
-from dotenv import load_dotenv
+
 import requests
 
+from dotenv import load_dotenv
+
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 load_dotenv()
diff --git a/evaluation/scripts/utils/mirix_utils.py b/evaluation/scripts/utils/mirix_utils.py
index e1b5f3de6..63cd490df 100644
--- a/evaluation/scripts/utils/mirix_utils.py
+++ b/evaluation/scripts/utils/mirix_utils.py
@@ -1,18 +1,21 @@
 import os
+
 import yaml
+
 from tqdm import tqdm
 
 
 def get_mirix_client(config_path, load_from=None):
-    if os.path.exists(os.path.expanduser(f"~/.mirix")):
-        os.system(f"rm -rf ~/.mirix/*")
+    if os.path.exists(os.path.expanduser("~/.mirix")):
+        os.system("rm -rf ~/.mirix/*")
 
-    with open(config_path, "r") as f:
+    with open(config_path) as f:
         agent_config = yaml.safe_load(f)
 
     os.environ["OPENAI_API_KEY"] = agent_config["api_key"]
 
     import mirix
-    from mirix import Mirix, EmbeddingConfig, LLMConfig
+
+    from mirix import EmbeddingConfig, LLMConfig, Mirix
 
     embedding_default_config = EmbeddingConfig(
         embedding_model=agent_config["embedding_model_name"],
diff --git a/evaluation/scripts/utils/pref_mem_utils.py b/evaluation/scripts/utils/pref_mem_utils.py
new file mode 100644
index 000000000..22a5bb86c
--- /dev/null
+++ b/evaluation/scripts/utils/pref_mem_utils.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from prompts import PREF_INSTRUCTIONS
+
+
+def create_mem_string(relevant_memories) -> str:
+    text_memories = []
+    explicit = []
+    implicit = []
+    for item in relevant_memories["text_mem"]:
+        for mem in item["memories"]:
+            text_memories.append(mem["memory"])
+    text_memories_text = "\n".join(f"{i + 1}. {mem}" for i, mem in enumerate(text_memories)).strip()
+    text_context = f"Plaintext Memory:\n{text_memories_text}\n" if text_memories_text else ""
+
+    for item in relevant_memories.get("prefs", []):
+        for mem in item["memories"]:
+            if mem["metadata"]["preference_type"] == "explicit_preference":
+                explicit.append(mem["metadata"]["explicit_preference"])
+            elif mem["metadata"]["preference_type"] == "implicit_preference":
+                implicit.append(mem["metadata"]["implicit_preference"])
+    explicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(explicit)).strip()
+    explicit_context = f"Explicit Preference:\n{explicit_text}\n" if explicit_text else ""
+    implicit_text = "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(implicit)).strip()
+    implicit_context = f"Implicit Preference:\n{implicit_text}\n" if implicit_text else ""
+    return text_context + explicit_context + implicit_context
+
+
+def remove_pref_mem_from_mem_string(mem_string: str, frame: str) -> str:
+    if os.getenv("ABLATION_PREF", "false").lower() == "true" and frame == "memos-api":
+        tmp_list = mem_string.split("Plaintext Memory:")
+        if len(tmp_list) > 1:
+            return tmp_list[1].split("Explicit Preference:")[0]
+    return mem_string
+
+
+def add_pref_instruction(template: str, frame: str):
+    if os.getenv("INSTRUCT_COMPLETE", "false").lower() == "true" and frame == "memos-api":
+        return template.replace("{pref_instructions}", PREF_INSTRUCTIONS)
+    return template.replace("{pref_instructions}", "")
diff --git a/evaluation/scripts/utils/prompts.py b/evaluation/scripts/utils/prompts.py
index bd418af54..902bbb1be 100644
--- a/evaluation/scripts/utils/prompts.py
+++ b/evaluation/scripts/utils/prompts.py
@@ -1,3 +1,11 @@
+PREF_INSTRUCTIONS = """
+    # Note:
+    Plaintext memories are summaries of facts, while preference memories are summaries of user preferences.
+    Your response must not violate any of the user's preferences, whether explicit or implicit, and briefly explain why you answer this way to avoid conflicts.
+    When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory.
+"""
+
+
 LME_ANSWER_PROMPT = """
     You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
@@ -17,7 +25,7 @@
     5. Formulate a precise, concise answer based solely on the evidence in the memories.
     6. Double-check that your answer directly addresses the question asked.
     7. Ensure your final answer is specific and avoids vague time references.
-    
+    {pref_instructions}
     {context}
 
     Current Date: {question_date}
@@ -27,6 +35,7 @@
     Answer:
     """
 
+
 PM_ANSWER_PROMPT = """
     You are a helpful assistant tasked with selecting the best answer to a user question, based solely on summarized conversation memories.
@@ -46,7 +55,7 @@
     - Your final answer **must use parentheses**, like (a) or (b).
     - Do NOT list multiple choices. Choose only one.
     - Do NOT include extra text after . Just output the answer.
-    
+    {pref_instructions}
 
 # QUESTION:
 {question}
@@ -58,6 +67,14 @@
 """
 
 
+PREFEVAL_ANSWER_PROMPT = """
+    You are a helpful AI. Answer the question based on the query and the following memories:
+    User Memories:
+    {context}
+    {pref_instructions}
+"""
+
+
 ZEP_CONTEXT_TEMPLATE = """
 FACTS and ENTITIES represent relevant context to the current conversation.
diff --git a/examples/mem_os/simple_prefs_memos_product.py b/examples/mem_os/simple_prefs_memos_product.py
new file mode 100644
index 000000000..40ec920f5
--- /dev/null
+++ b/examples/mem_os/simple_prefs_memos_product.py
@@ -0,0 +1,399 @@
+from memos.configs.mem_cube import GeneralMemCubeConfig
+from memos.configs.mem_os import MOSConfig
+from memos.mem_cube.general import GeneralMemCube
+from memos.mem_os.product import MOSProduct
+
+
+def get_config(user_id: str):
+    llm_config = {
+        "backend": "openai",
+        "config": {
+            "model_name_or_path": "gpt-4o-mini",
+            "api_key": "sk-xxxxx",
+            "api_base": "http://xxxx/v1",
+            "temperature": 0.1,
+            "remove_think_prefix": True,
+            "max_tokens": 4096,
+        },
+    }
+
+    embedder_config = {
+        "backend": "ollama",
+        "config": {"model_name_or_path": "nomic-embed-text:latest"},
+    }
+
+    # init MOS
+    mos_config = {
+        "user_id": user_id,
+        "chat_model": llm_config,
+        "mem_reader": {
+            "backend": "simple_struct",
+            "config": {
+                "llm": llm_config,
+                "embedder": embedder_config,
+                "chunker": {
+                    "backend": "sentence",
+                    "config": {
+                        "tokenizer_or_token_counter": "gpt2",
+                        "chunk_size": 512,
+                        "chunk_overlap": 128,
+                        "min_sentences_per_chunk": 1,
+                    },
+                },
+            },
+        },
+        "max_turns_window": 20,
+        "top_k": 5,
+        "enable_textual_memory": True,
+        "enable_activation_memory": False,
+        "enable_parametric_memory": False,
+        "enable_preference_memory": True,
+    }
+
+    cube_config = {
+        "model_schema": "memos.configs.mem_cube.GeneralMemCubeConfig",
+        "user_id": user_id,
+        "cube_id": f"{user_id}/mem_cube",
+        "text_mem": {
+            "backend": "tree_text",
+            "config": {
+                "cube_id": f"{user_id}/mem_cube",
+                "extractor_llm": llm_config,
+                "dispatcher_llm": llm_config,
+                "graph_db": {
+                    "backend": "neo4j",
+                    "config": {
+                        "uri": "bolt://localhost:7687",
+                        "user": "neo4j",
+                        "password": "12345678",
+                        "db_name": "neo4j",
+                        "user_name": "memosneo4j",
+                        "embedding_dimension": 768,
+                        "use_multi_db": False,
+                        "auto_create": False,
+                    },
+                },
+                "embedder": embedder_config,
+            },
+        },
+        "act_mem": {"backend": "uninitialized", "config": {}},
+        "para_mem": {"backend": "uninitialized", "config": {}},
+        "pref_mem": {
+            "backend": "pref_text",
+            "config": {
+                "cube_id": f"{user_id}/mem_cube",
+                "extractor_llm": llm_config,
+                "vector_db": {
+                    "backend": "milvus",
+                    "config": {
+                        "collection_name": [
+                            "explicit_preference",
+                            "implicit_preference",
+                        ],
+                        "vector_dimension": 768,
+                        "distance_metric": "cosine",
+                        "uri": "./milvus_demo.db",
+                    },
+                },
+                "embedder": embedder_config,
+                "extractor": {"backend": "naive", "config": {}},
+                "adder": {"backend": "naive", "config": {}},
+                "retriever": {"backend": "naive", "config": {}},
+            },
+        },
+    }
+    mos_config = MOSConfig(**mos_config)
+    cube_config = GeneralMemCubeConfig.model_validate(cube_config)
+
+    return mos_config, cube_config
+
+
+user_id = "superman"
+mos_config, mem_cube_config = get_config(user_id)
+
+memory = MOSProduct(default_config=mos_config)
+
+# register user and mem cube
+mem_cube = GeneralMemCube(mem_cube_config)
+
+register_result = memory.user_register(
+    user_id=user_id, user_name=user_id, default_mem_cube=mem_cube
+)
+print(f"register_result: {register_result}")
+
+# add memory for user
+memory.add(
+    messages=[
+        # Sports and hobbies
+        {"role": "user", "content": "I like playing football."},
+        {
+            "role": "assistant",
+            "content": "I like playing football too. It's a great way to stay active and work as a team.",
+        },
+        {"role": "user", "content": "Do you enjoy watching basketball games?"},
+        {
+            "role": "assistant",
+            "content": "Yes, I find basketball exciting to watch, especially the fast-paced action and strategic plays.",
+        },
+        {"role": "user", "content": "I'm learning to play tennis this summer."},
+        {
+            "role": "assistant",
+            "content": "That's wonderful! Tennis is a fantastic sport that improves both physical fitness and mental focus.",
+        },
+        # Food and cooking
+        {"role": "user", "content": "What's your favorite type of cuisine?"},
+        {
+            "role": "assistant",
+            "content": "I really enjoy Italian cuisine, especially pasta dishes and authentic pizza.",
+        },
+        {"role": "user", "content": "I love cooking Asian food at home."},
+        {
+            "role": "assistant",
+            "content": "Asian cooking is amazing! The balance of flavors and techniques is so interesting to learn.",
+        },
+        {"role": "user", "content": "Have you ever tried making homemade bread?"},
+        {
+            "role": "assistant",
+            "content": "Yes, there's something magical about the process of making bread from scratch. The smell is incredible!",
+        },
+        # Travel and places
+        {"role": "user", "content": "I'm planning a trip to Japan next year."},
+        {
+            "role": "assistant",
+            "content": "Japan is an incredible destination! The culture, food, and technology blend is fascinating.",
+        },
+        {"role": "user", "content": "What's the most beautiful place you've ever seen?"},
+        {
+            "role": "assistant",
+            "content": "I find mountain landscapes particularly breathtaking, especially during sunrise or sunset.",
+        },
+        {"role": "user", "content": "I prefer beach vacations over city trips."},
+        {
+            "role": "assistant",
+            "content": "Beach vacations are so relaxing! The sound of waves and warm sand is perfect for unwinding.",
+        },
+        # Technology and innovation
+        {"role": "user", "content": "What do you think about artificial intelligence?"},
+        {
+            "role": "assistant",
+            "content": "AI is transforming many aspects of our lives. It's exciting to see how it can help solve complex problems.",
+        },
+        {"role": "user", "content": "I'm learning to code in Python."},
+        {
+            "role": "assistant",
+            "content": "Python is a great language to start with! It's versatile and has a wonderful community.",
+        },
+        {"role": "user", "content": "Do you use any productivity apps?"},
+        {
+            "role": "assistant",
+            "content": "I find task management apps really helpful for organizing daily activities and staying focused.",
+        },
+        # Music and entertainment
+        {"role": "user", "content": "What kind of music do you enjoy?"},
+        {
+            "role": "assistant",
+            "content": "I appreciate many genres, but I particularly enjoy jazz and classical music for their complexity.",
+        },
+        {"role": "user", "content": "I play the guitar in my free time."},
+        {
+            "role": "assistant",
+            "content": "That's fantastic! Playing an instrument is such a rewarding creative outlet.",
+        },
+        {"role": "user", "content": "Have you seen any good movies lately?"},
+        {
+            "role": "assistant",
+            "content": "I enjoy films that make you think deeply about human nature and society.",
+        },
+        # Books and learning
+        {"role": "user", "content": "Do you read books regularly?"},
+        {
+            "role": "assistant",
+            "content": "Yes, reading is one of my favorite ways to learn and explore new ideas.",
+        },
+        {"role": "user", "content": "I'm currently reading a science fiction novel."},
+        {
+            "role": "assistant",
+            "content": "Science fiction is great for exploring possibilities and imagining different futures.",
+        },
+        {"role": "user", "content": "What's your favorite subject to learn about?"},
+        {
+            "role": "assistant",
+            "content": "I'm fascinated by psychology and how the human mind works.",
+        },
+        # Nature and environment
+        {"role": "user", "content": "I love hiking in the mountains."},
+        {
+            "role": "assistant",
+            "content": "Hiking is wonderful! It's a great way to connect with nature and get exercise.",
+        },
+        {"role": "user", "content": "Are you concerned about climate change?"},
+        {
+            "role": "assistant",
+            "content": "Yes, it's one of the most important challenges of our time that requires collective action.",
+        },
+        {"role": "user", "content": "I try to recycle and reduce waste at home."},
+        {
+            "role": "assistant",
+            "content": "That's commendable! Every small action contributes to environmental protection.",
+        },
+        # Work and career
+        {"role": "user", "content": "What do you do for work?"},
+        {
+            "role": "assistant",
+            "content": "I work in technology, helping people solve problems and improve their daily lives.",
+        },
+        {"role": "user", "content": "I'm considering a career change."},
+        {
+            "role": "assistant",
+            "content": "Career changes can be challenging but also very rewarding when you find the right fit.",
+        },
+        {"role": "user", "content": "Do you work from home or in an office?"},
+        {
+            "role": "assistant",
+            "content": "I work remotely, which gives me flexibility in how I organize my day.",
+        },
+        # Health and wellness
+        {"role": "user", "content": "I've started doing yoga in the morning."},
+        {
+            "role": "assistant",
+            "content": "Yoga is excellent for both physical and mental well-being. Great way to start the day!",
+        },
+        {"role": "user", "content": "How do you manage stress?"},
+        {
+            "role": "assistant",
+            "content": "I find meditation and regular exercise help me stay balanced and focused.",
+        },
+        {"role": "user", "content": "I'm trying to eat more vegetables."},
+        {
+            "role": "assistant",
+            "content": "That's a great health goal! Vegetables provide so many essential nutrients.",
+        },
+        # Art and creativity
+        {"role": "user", "content": "Do you enjoy painting or drawing?"},
+        {
+            "role": "assistant",
+            "content": "I appreciate visual arts, though I'm more of an observer than a creator.",
+        },
+        {"role": "user", "content": "I'm learning photography."},
+        {
+            "role": "assistant",
+            "content": "Photography is a beautiful way to capture moments and express your perspective.",
+        },
+        {"role": "user", "content": "What's your favorite type of art?"},
+        {
+            "role": "assistant",
+            "content": "I'm drawn to abstract art because it allows for personal interpretation.",
+        },
+        # Family and relationships
+        {"role": "user", "content": "I have two siblings."},
+        {
+            "role": "assistant",
+            "content": "Having siblings can be wonderful - they're often our first friends and lifelong companions.",
+        },
+        {"role": "user", "content": "Do you have any pets?"},
+        {
+            "role": "assistant",
+            "content": "I don't have pets, but I think they bring so much joy and companionship to people's lives.",
+        },
+        {"role": "user", "content": "I'm planning to visit my grandparents this weekend."},
+        {
+            "role": "assistant",
+            "content": "That's lovely! Grandparents often have such interesting stories and wisdom to share.",
+        },
+        # Education and growth
+        {"role": "user", "content": "I'm taking an online course in data science."},
+        {
+            "role": "assistant",
+            "content": "Data science is a fascinating field with many applications across different industries.",
+        },
+        {"role": "user", "content": "What's the best way to learn a new language?"},
+        {
+            "role": "assistant",
+            "content": "Immersion and consistent practice are key, along with finding methods that work for your learning style.",
+        },
+        {"role": "user", "content": "I enjoy attending workshops and conferences."},
+        {
+            "role": "assistant",
+            "content": "Workshops are great for hands-on learning and networking with like-minded people.",
+        },
+        # Hobbies and interests
+        {"role": "user", "content": "I collect vintage postcards."},
+        {
+            "role": "assistant",
+            "content": "What a unique hobby! Postcards offer glimpses into different times and places.",
+        },
+        {"role": "user", "content": "Do you garden?"},
+        {
+            "role": "assistant",
+            "content": "I appreciate gardening, though I'm more of a plant observer than a gardener myself.",
+        },
+        {"role": "user", "content": "I love solving puzzles and brain teasers."},
+        {
+            "role": "assistant",
+            "content": "Puzzles are excellent for keeping the mind sharp and developing problem-solving skills.",
+        },
+        # Future and goals
+        {"role": "user", "content": "What are your goals for next year?"},
+        {
+            "role": "assistant",
+            "content": "I aim to continue learning and growing, helping more people with their challenges.",
+        },
+        {"role": "user", "content": "I want to start my own business someday."},
+        {
+            "role": "assistant",
+            "content": "Entrepreneurship is exciting! It requires courage, planning, and persistence.",
+        },
+        {"role": "user", "content": "Do you believe in setting New Year's resolutions?"},
+        {
+            "role": "assistant",
+            "content": "I think any time is good for setting positive goals and making positive changes.",
+        },
+        # Philosophy and life
+        {"role": "user", "content": "What's the meaning of life to you?"},
+        {
+            "role": "assistant",
+            "content": "I believe it's about learning, growing, and making positive connections with others.",
+        },
+        {"role": "user", "content": "I try to practice gratitude daily."},
+        {
+            "role": "assistant",
+            "content": "Gratitude is such a powerful practice that can transform your perspective on life.",
+        },
+        {"role": "user", "content": "What's your biggest fear?"},
+        {
+            "role": "assistant",
+            "content": "I think fear of not being able to help when someone really needs it would be challenging.",
+        },
+        # Technology and daily life
+        {"role": "user", "content": "How has technology changed your daily routine?"},
+        {
+            "role": "assistant",
+            "content": "Technology has made many tasks more efficient and connected me with people worldwide.",
+        },
+        {"role": "user", "content": "I'm trying to reduce my screen time."},
+        {
+            "role": "assistant",
+            "content": "That's a healthy goal! Finding balance with technology is important for well-being.",
+        },
+        {"role": "user", "content": "Do you use social media?"},
+        {
+            "role": "assistant",
+            "content": "I think social media can be a great tool for connection when used mindfully.",
+        },
+        # Final thoughts
+        {"role": "user", "content": "This has been a great conversation!"},
+        {
+            "role": "assistant",
+            "content": "I've really enjoyed our chat too! It's wonderful to explore so many different topics together.",
+        },
+    ],
+    user_id=user_id,
+    mem_cube_id=register_result["default_cube_id"],
+)
+
+retrieved_memories = memory.search(query="What do you like?", user_id=user_id)
+print(
+    f"len_pref_memories: {len(retrieved_memories['pref_mem'][0]['memories'])}"
+    if retrieved_memories["pref_mem"]
+    else 0
+)
diff --git a/poetry.lock b/poetry.lock
index d34f964b6..44265bca8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]] name = "absl-py" @@ -690,6 +690,30 @@ toml = ["tomli (>=2.0.0) ; python_version < \"3.11\""] trio = ["trio (>=0.10.0)"] yaml = ["pyyaml (>=6.0.1)"] +[[package]] +name = "datasketch" +version = "1.6.5" +description = "Probabilistic data structures for processing and searching very large datasets" +optional = true +python-versions = "*" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "datasketch-1.6.5-py3-none-any.whl", hash = "sha256:59311b2925b2f37536e9f7c2f46bbc25e8e54379c8635a3fa7ca55d2abb66d1b"}, + {file = "datasketch-1.6.5.tar.gz", hash = "sha256:ba2848cb74f23d6d3dd444cf24edcbc47b1c34a171b1803231793ed4d74d4fcf"}, +] + +[package.dependencies] +numpy = ">=1.11" +scipy = ">=1.0.0" + +[package.extras] +benchmark = ["SetSimilaritySearch (>=0.1.7)", "matplotlib (>=3.1.2)", "nltk (>=3.4.5)", "pandas (>=0.25.3)", "pyfarmhash (>=0.2.2)", "pyhash (>=0.9.3)", "scikit-learn (>=0.21.3)", "scipy (>=1.3.3)"] +cassandra = ["cassandra-driver (>=3.20)"] +experimental-aio = ["aiounittest ; python_version >= \"3.6\"", "motor ; python_version >= \"3.6\""] +redis = ["redis (>=2.10.0)"] +test = ["cassandra-driver (>=3.20)", "coverage", "mock (>=2.0.0)", "mockredispy", "nose (>=1.3.7)", "nose-exclude (>=0.5.0)", "pymongo (>=3.9.0)", "pytest", "redis (>=2.10.0)"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1222,7 +1246,7 @@ files = [ {file = "grpcio-1.73.1-cp39-cp39-win_amd64.whl", hash = "sha256:42f0660bce31b745eb9d23f094a332d31f210dcadd0fc8e5be7e4c62a87ce86b"}, {file = "grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87"}, ] -markers = {main = "extra == \"all\""} +markers = {main = "extra == \"pref-mem\" or extra == \"all\""} [package.extras] protobuf = ["grpcio-tools (>=1.73.1)"] @@ -3241,7 +3265,7 @@ files = [ {file = "pandas-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:b4b0de34dc8499c2db34000ef8baad684cfa4cbd836ecee05f323ebfba348c7d"}, {file = "pandas-2.3.1.tar.gz", hash = "sha256:0a95b9ac964fe83ce317827f80304d37388ea77616b1425f0ae41c9d2d0d7bb2"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [package.dependencies] numpy = [ @@ -3560,7 +3584,7 @@ files = [ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "pycparser" @@ -3773,6 +3797,33 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pymilvus" +version = "2.6.2" +description = "Python Sdk for Milvus" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"pref-mem\" or extra == \"all\"" +files = [ + {file = "pymilvus-2.6.2-py3-none-any.whl", hash = "sha256:933e447e09424d490dcf595053b01a7277dadea7ae3235cd704363bd6792509d"}, + {file = "pymilvus-2.6.2.tar.gz", hash = "sha256:b4802cc954de8f2d47bf8d6230e92196514dcb8a3726ba6098dc27909d4bc8e3"}, +] + +[package.dependencies] +grpcio = ">=1.66.2,<1.68.0 || >1.68.0,<1.68.1 || >1.68.1,<1.69.0 || >1.69.0,<1.70.0 || >1.70.0,<1.70.1 || >1.70.1,<1.71.0 || >1.71.0,<1.72.1 || >1.72.1,<1.73.0 || >1.73.0" +pandas = ">=1.2.4" +protobuf = ">=5.27.2" 
+python-dotenv = ">=1.0.1,<2.0.0" +setuptools = ">69" +ujson = ">=2.0.0" + +[package.extras] +bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests", "urllib3"] +dev = ["azure-storage-blob", "black", "grpcio (==1.66.2)", "grpcio-testing (==1.66.2)", "grpcio-tools (==1.66.2)", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "pytest (>=5.3.4)", "pytest-asyncio", "pytest-cov (>=5.0.0)", "pytest-timeout (>=1.3.4)", "requests", "ruff (>=0.12.9,<1)", "scipy", "urllib3"] +milvus-lite = ["milvus-lite (>=2.4.0) ; sys_platform != \"win32\""] +model = ["pymilvus.model (>=0.3.0)"] + [[package]] name = "pymysql" version = "1.1.2" @@ -3946,7 +3997,7 @@ files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] -markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\""} +markers = {main = "extra == \"tree-mem\" or extra == \"all\" or extra == \"mem-reader\" or extra == \"pref-mem\""} [[package]] name = "pywin32" @@ -4955,7 +5006,7 @@ files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] -markers = {main = "extra == \"all\" and platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\" and extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} +markers = {main = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and (extra == \"all\" or extra == \"pref-mem\") or extra == \"pref-mem\" or extra == \"all\"", eval = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""} [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] @@ -5578,7 +5629,7 @@ files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, ] -markers = {main = "extra == \"mem-reader\" or extra == \"all\""} +markers = {main = "extra == \"mem-reader\" or extra == \"all\" or extra == \"pref-mem\""} [[package]] name = "ujson" @@ -6301,13 +6352,14 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["chonkie", "markitdown", "neo4j", "pika", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] +all = ["chonkie", "datasketch", "markitdown", "neo4j", "pika", "pymilvus", "pymysql", "qdrant-client", "redis", "schedule", "sentence-transformers", "torch", "volcengine-python-sdk"] mem-reader = ["chonkie", "markitdown"] mem-scheduler = ["pika", "redis"] mem-user = ["pymysql"] +pref-mem = ["datasketch", "pymilvus"] tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "d85cb8a08870d67df6e462610231f1e735ba5293bd3fe5b0c4a212b3ccff7b72" \ No newline at end of file +content-hash = "3f0d0c9a996f87d945ef8bf83eed3e20f8c420b6b39e12012d0147eda2bf4d38" diff --git a/pyproject.toml 
b/pyproject.toml index a03b9174b..3745582f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,12 @@ mem-reader = [ "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", # Markdown parser for various file formats ] +# PreferenceTextMemory +pref-mem = [ + "pymilvus (>=2.6.1,<3.0.0)", # Milvus Vector DB + "datasketch (>=1.6.5,<2.0.0)", # MinHash library +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -99,6 +105,8 @@ all = [ "pymysql (>=1.1.0,<2.0.0)", "chonkie (>=1.0.7,<2.0.0)", "markitdown[docx,pdf,pptx,xls,xlsx] (>=0.1.1,<0.2.0)", + "pymilvus (>=2.6.1,<3.0.0)", + "datasketch (>=1.6.5,<2.0.0)", # NOT exist in the above optional groups # Because they are either huge-size dependencies or infrequently used dependencies. diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d552369c5..d26672883 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -108,6 +108,25 @@ def get_activation_vllm_config() -> dict[str, Any]: }, } + @staticmethod + def get_preference_memory_config() -> dict[str, Any]: + """Get preference memory configuration.""" + return { + "backend": "pref_text", + "config": { + "extractor_llm": {"backend": "openai", "config": APIConfig.get_openai_config()}, + "vector_db": { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + }, + "embedder": APIConfig.get_embedder_config(), + "reranker": APIConfig.get_reranker_config(), + "extractor": {"backend": "naive", "config": {}}, + "adder": {"backend": "naive", "config": {}}, + "retriever": {"backend": "naive", "config": {}}, + }, + } + @staticmethod def get_reranker_config() -> dict[str, Any]: """Get embedder configuration.""" @@ -275,6 +294,20 @@ def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } + @staticmethod + def get_milvus_config(): + return { + "collection_name": [ + "explicit_preference", + "implicit_preference", + ], + "vector_dimension": int(os.getenv("EMBEDDING_DIMENSION", 1024)), + "distance_metric": "cosine", + "uri": os.getenv("MILVUS_URI", "http://localhost:19530"), + "user_name": os.getenv("MILVUS_USER_NAME", "root"), + "password": os.getenv("MILVUS_PASSWORD", "12345678"), + } + @staticmethod def get_mysql_config() -> dict[str, Any]: """Get MySQL configuration.""" @@ -385,6 +418,8 @@ def get_product_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "50")), "max_turns_window": int(os.getenv("MOS_MAX_TURNS_WINDOW", "20")), } @@ -414,6 +449,8 @@ def get_start_default_config() -> dict[str, Any]: "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": int(os.getenv("MOS_TOP_K", "5")), "chat_model": { "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"), @@ -478,6 +515,8 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General "enable_textual_memory": True, "enable_activation_memory": os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "true", + "enable_preference_memory": os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() + == "true", "top_k": 30, "max_turns_window": 20, } @@ 
-543,6 +582,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: @@ -605,6 +647,9 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: if os.getenv("ENABLE_ACTIVATION_MEMORY", "false").lower() == "false" else APIConfig.get_activation_vllm_config(), "para_mem": {}, + "pref_mem": {} + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "false" + else APIConfig.get_preference_memory_config(), } ) else: diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index d14c05993..e491e9feb 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -180,6 +180,7 @@ class APISearchRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) + handle_pref_mem: bool = Field(False, description="Whether to handle preference memory") class APIADDRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 9f982ddd3..d2392f927 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -2,6 +2,7 @@ import os import traceback +from concurrent.futures import ThreadPoolExecutor from typing import Any from fastapi import APIRouter, HTTPException @@ -21,6 +22,7 @@ from memos.configs.mem_reader import MemReaderConfigFactory from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -36,12 +38,24 @@ ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) from memos.reranker.factory import RerankerFactory +from memos.templates.instruction_completion import instruct_completion from memos.types import MOSSearchResult, UserContext +from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) @@ -66,6 +80,16 @@ def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: ) +def _build_vec_db_config() -> dict[str, Any]: + """Build vector database configuration.""" + return VectorDBConfigFactory.model_validate( + { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + } + ) + + def _build_llm_config() -> dict[str, Any]: """Build LLM configuration.""" return LLMConfigFactory.model_validate( @@ -98,6 +122,21 @@ def _build_internet_retriever_config() -> dict[str, Any]: return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) +def _build_pref_extractor_config() -> dict[str, Any]: + """Build extractor configuration.""" + 
return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def _build_pref_adder_config() -> dict[str, Any]: + """Build adder configuration.""" + return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def _build_pref_retriever_config() -> dict[str, Any]: + """Build retriever configuration.""" + return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) + + def _get_default_memory_size(cube_config) -> dict[str, int]: """Get default memory size configuration.""" return getattr(cube_config.text_mem.config, "memory_size", None) or { @@ -120,9 +159,14 @@ def init_server(): mem_reader_config = _build_mem_reader_config() reranker_config = _build_reranker_config() internet_retriever_config = _build_internet_retriever_config() + vector_db_config = _build_vec_db_config() + pref_extractor_config = _build_pref_extractor_config() + pref_adder_config = _build_pref_adder_config() + pref_retriever_config = _build_pref_retriever_config() # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) + vector_db = VecDBFactory.from_config(vector_db_config) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) @@ -130,6 +174,25 @@ def init_server(): internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) + pref_extractor = ExtractorFactory.from_config( + config_factory=pref_extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + pref_adder = AdderFactory.from_config( + config_factory=pref_adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + pref_retriever = RetrieverFactory.from_config( + config_factory=pref_retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) # Initialize memory manager memory_manager = MemoryManager( @@ -170,6 +233,10 @@ def init_server(): internet_retriever=internet_retriever, memory_manager=memory_manager, default_cube_config=default_cube_config, + vector_db=vector_db, + pref_extractor=pref_extractor, + pref_adder=pref_adder, + pref_retriever=pref_retriever, ) return ( @@ -185,6 +252,10 @@ def init_server(): mem_scheduler, naive_mem_cube, api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, ) @@ -202,6 +273,10 @@ def init_server(): mem_scheduler, naive_mem_cube, api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, ) = init_server() @@ -221,6 +296,28 @@ def _format_memory_item(memory_data: Any) -> dict[str, Any]: return memory +def _post_process_pref_mem( + memories_result: dict[str, Any], + pref_formatted_mem: list[dict[str, Any]], + mem_cube_id: str, + handle_pref_mem: bool, +) -> dict[str, Any]: + if os.getenv("RETURN_ORIGINAL_PREF_MEM", "false").lower() == "true" and pref_formatted_mem: + memories_result["prefs"] = [] + memories_result["prefs"].append( + { + "cube_id": mem_cube_id, + "memories": pref_formatted_mem, + } + ) + + if handle_pref_mem: + pref_instruction: str = instruct_completion(pref_formatted_mem) + memories_result["pref_mem"] = pref_instruction + + return memories_result + + @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: APISearchRequest): """Search memories for a specific user.""" @@ -239,23 +336,55 @@ def search_memories(search_req: APISearchRequest): search_mode = search_req.mode - if search_mode == 
SearchMode.FAST: - formatted_memories = fast_search_memories(search_req=search_req, user_context=user_context) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories(search_req=search_req, user_context=user_context) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories(search_req=search_req, user_context=user_context) - else: - logger.error(f"Unsupported search mode: {search_mode}") - raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + def _search_text(): + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories( + search_req=search_req, user_context=user_context + ) + else: + logger.error(f"Unsupported search mode: {search_mode}") + raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") + return formatted_memories + + def _search_pref(): + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + results = naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [_format_memory_item(data) for data in results] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_search_text) + pref_future = executor.submit(_search_pref) + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() memories_result["text_mem"].append( { "cube_id": search_req.mem_cube_id, - "memories": formatted_memories, + "memories": text_formatted_memories, } ) + memories_result = _post_process_pref_mem( + memories_result, pref_formatted_memories, search_req.mem_cube_id, search_req.handle_pref_mem + ) + return SearchResponse( message="Search completed successfully", data=memories_result, @@ -431,38 +560,69 @@ def add_memories(add_req: APIADDRequest): target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" - memories = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - # Flatten memory list - flattened_memories = [mm for m in memories for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_id_list: list[str] = naive_mem_cube.text_mem.add( - flattened_memories, - user_name=user_context.mem_cube_id, - ) + def _process_text_mem() -> list[dict[str, str]]: + memories_local = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + flattened_local = [mm for m in memories_local for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + mem_ids_local: list[str] = naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, 
flattened_local, strict=False) + ] + + def _process_pref_mem() -> list[dict[str, str]]: + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_process_text_mem) + pref_future = executor.submit(_process_pref_mem) + text_response_data = text_future.result() + pref_response_data = pref_future.result() - logger.info( - f"Added {len(mem_id_list)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_id_list}" - ) - response_data = [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) - ] return MemoryResponse( message="Memory added successfully", - data=response_data, + data=text_response_data + pref_response_data, ) diff --git a/src/memos/configs/mem_cube.py b/src/memos/configs/mem_cube.py index b9868fa99..4bd709fab 100644 --- a/src/memos/configs/mem_cube.py +++ b/src/memos/configs/mem_cube.py @@ -54,6 +54,11 @@ class GeneralMemCubeConfig(BaseMemCubeConfig): default_factory=MemoryConfigFactory, description="Configuration for the parametric memory", ) + pref_mem: MemoryConfigFactory = Field( + ..., + default_factory=MemoryConfigFactory, + description="Configuration for the preference memory", + ) @field_validator("text_mem") @classmethod @@ -87,3 +92,14 @@ def validate_para_mem(cls, para_mem: MemoryConfigFactory) -> MemoryConfigFactory f"GeneralMemCubeConfig requires para_mem backend to be one of {allowed_backends}, got '{para_mem.backend}'" ) return para_mem + + @field_validator("pref_mem") + @classmethod + def validate_pref_mem(cls, pref_mem: MemoryConfigFactory) -> MemoryConfigFactory: + """Validate the pref_mem field.""" + allowed_backends = ["pref_text", "uninitialized"] + if pref_mem.backend not in allowed_backends: + raise ConfigurationError( + f"GeneralMemCubeConfig requires pref_mem backend to be one of {allowed_backends}, got '{pref_mem.backend}'" + ) + return pref_mem diff --git a/src/memos/configs/mem_os.py b/src/memos/configs/mem_os.py index 0645fce44..549e55792 100644 --- a/src/memos/configs/mem_os.py +++ b/src/memos/configs/mem_os.py @@ -58,6 +58,10 @@ class MOSConfig(BaseConfig): default=False, description="Enable parametric memory for the MemChat", ) + enable_preference_memory: bool = Field( + default=False, + description="Enable preference memory for the MemChat", + ) enable_mem_scheduler: bool = Field( default=False, description="Enable memory scheduler for automated memory management", diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 2c3a715f7..bf2493567 100644 --- a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -10,6 +10,11 @@ from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory from 
memos.exceptions import ConfigurationError +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) # ─── 1. Global Base Memory Config ───────────────────────────────────────────── @@ -189,6 +194,45 @@ class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): """Simple tree text memory configuration class.""" +class PreferenceTextMemoryConfig(BaseTextMemoryConfig): + """Preference memory configuration class.""" + + extractor_llm: LLMConfigFactory = Field( + ..., + default_factory=LLMConfigFactory, + description="LLM configuration for the memory extractor", + ) + vector_db: VectorDBConfigFactory = Field( + ..., + default_factory=VectorDBConfigFactory, + description="Vector database configuration for the memory storage", + ) + embedder: EmbedderConfigFactory = Field( + ..., + default_factory=EmbedderConfigFactory, + description="Embedder configuration for the memory embedding", + ) + reranker: RerankerConfigFactory | None = Field( + None, + description="Reranker configuration (optional).", + ) + extractor: ExtractorConfigFactory = Field( + ..., + default_factory=ExtractorConfigFactory, + description="Extractor configuration for the memory extracting", + ) + adder: AdderConfigFactory = Field( + ..., + default_factory=AdderConfigFactory, + description="Adder configuration for the memory adding", + ) + retriever: RetrieverConfigFactory = Field( + ..., + default_factory=RetrieverConfigFactory, + description="Retriever configuration for the memory retrieving", + ) + + # ─── 3. Global Memory Config Factory ────────────────────────────────────────── @@ -203,6 +247,7 @@ class MemoryConfigFactory(BaseConfig): "general_text": GeneralTextMemoryConfig, "simple_tree_text": SimpleTreeTextMemoryConfig, "tree_text": TreeTextMemoryConfig, + "pref_text": PreferenceTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache "lora": LoRAMemoryConfig, diff --git a/src/memos/mem_cube/base.py b/src/memos/mem_cube/base.py index 7d7c5e779..349d511fb 100644 --- a/src/memos/mem_cube/base.py +++ b/src/memos/mem_cube/base.py @@ -19,6 +19,7 @@ def __init__(self, config: BaseMemCubeConfig): self.text_mem: BaseTextMemory self.act_mem: BaseActMemory self.para_mem: BaseParaMemory + self.pref_mem: BaseTextMemory @abstractmethod def load(self, dir: str) -> None: diff --git a/src/memos/mem_cube/general.py b/src/memos/mem_cube/general.py index 17e45809c..1238ae050 100644 --- a/src/memos/mem_cube/general.py +++ b/src/memos/mem_cube/general.py @@ -41,16 +41,23 @@ def __init__(self, config: GeneralMemCubeConfig): if config.para_mem.backend != "uninitialized" else None ) + self._pref_mem: BaseTextMemory | None = ( + MemoryFactory.from_config(config.pref_mem) + if config.pref_mem.backend != "uninitialized" + else None + ) def load( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. 
- Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: @@ -61,7 +68,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -76,17 +83,23 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -98,7 +111,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -113,12 +126,16 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @staticmethod def init_from_dir( dir: str, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a MemCube directory. @@ -148,7 +165,7 @@ def init_from_dir( def init_from_remote_repo( cube_id: str, base_url: str = "https://huggingface.co/datasets", - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, default_config: GeneralMemCubeConfig | None = None, ) -> "GeneralMemCube": """Create a MemCube instance from a remote repository. @@ -207,3 +224,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. 
Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index 7ce3ca642..ba9f136b7 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -14,9 +14,14 @@ from memos.memories.activation.base import BaseActMemory from memos.memories.parametric.base import BaseParaMemory from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.prefer_text_memory.adder import BaseAdder +from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor +from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.reranker.base import BaseReranker +from memos.vec_dbs.base import BaseVecDB logger = get_logger(__name__) @@ -34,7 +39,11 @@ def __init__( reranker: BaseReranker, memory_manager: MemoryManager, default_cube_config: GeneralMemCubeConfig, + vector_db: BaseVecDB, internet_retriever: None = None, + pref_extractor: BaseExtractor | None = None, + pref_adder: BaseAdder | None = None, + pref_retriever: BaseRetriever | None = None, ): """Initialize the MemCube with a configuration.""" self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( @@ -49,6 +58,15 @@ def __init__( ) self._act_mem: BaseActMemory | None = None self._para_mem: BaseParaMemory | None = None + self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, + ) def load( self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None @@ -69,7 +87,7 @@ def load( # If no specific memory types specified, load all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Load specified memory types if "text_mem" in memory_types and self.text_mem: @@ -84,17 +102,23 @@ def load( self.para_mem.load(dir) logger.info(f"Loaded para_mem from {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.load(dir) + logger.info(f"Loaded pref_mem from {dir}") + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") def dump( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump memories. Args: dir (str): The directory where the memory files will be saved. memory_types (list[str], optional): List of memory types to dump. If None, dumps all available memory types. 
- Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ if os.path.exists(dir) and os.listdir(dir): raise MemCubeError( @@ -106,7 +130,7 @@ def dump( # If no specific memory types specified, dump all if memory_types is None: - memory_types = ["text_mem", "act_mem", "para_mem"] + memory_types = ["text_mem", "act_mem", "para_mem", "pref_mem"] # Dump specified memory types if "text_mem" in memory_types and self.text_mem: @@ -121,6 +145,10 @@ def dump( self.para_mem.dump(dir) logger.info(f"Dumped para_mem to {dir}") + if "pref_mem" in memory_types and self.pref_mem: + self.pref_mem.dump(dir) + logger.info(f"Dumped pref_mem to {dir}") + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") @property @@ -164,3 +192,17 @@ def para_mem(self, value: BaseParaMemory) -> None: if not isinstance(value, BaseParaMemory): raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") self._para_mem = value + + @property + def pref_mem(self) -> "BaseTextMemory | None": + """Get the preference memory.""" + if self._pref_mem is None: + logger.warning("Preference memory is not initialized. Returning None.") + return self._pref_mem + + @pref_mem.setter + def pref_mem(self, value: BaseTextMemory) -> None: + """Set the preference memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._pref_mem = value diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index cedffd6fb..ec8a673d7 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,6 +2,7 @@ import os import time +from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Lock @@ -18,6 +19,7 @@ ADD_LABEL, ANSWER_LABEL, MEM_READ_LABEL, + PREF_ADD_LABEL, QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -590,6 +592,7 @@ def search( "text_mem": [], "act_mem": [], "para_mem": [], + "pref_mem": [], } if install_cube_ids is None: install_cube_ids = user_cube_ids @@ -604,33 +607,78 @@ def search( ) for mem_cube_id, mem_cube in tmp_mem_cubes.items(): - if ( - (mem_cube_id in install_cube_ids) - and (mem_cube.text_mem is not None) - and self.config.enable_textual_memory - ): - time_start = time.time() - memories = mem_cube.text_mem.search( - query, - top_k=top_k if top_k else self.config.top_k, - mode=mode, - manual_close_internet=not internet_search, - info={ - "user_id": target_user_id, - "session_id": target_session_id, - "chat_history": chat_history.chat_history, - }, - moscube=moscube, - search_filter=search_filter, - ) - result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories}) - logger.info( - f"🧠 [Memory] Searched memories from {mem_cube_id}:\n{self._str_memories(memories)}\n" - ) - search_time_end = time.time() - logger.info( - f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" - ) + # Define internal functions for parallel search execution + def search_textual_memory(cube_id, cube): + if ( + (cube_id in install_cube_ids) + and (cube.text_mem is not None) + and self.config.enable_textual_memory + ): + time_start = time.time() + memories = cube.text_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + mode=mode, + manual_close_internet=not internet_search, + info={ + "user_id": target_user_id, + "session_id": target_session_id, + "chat_history": 
chat_history.chat_history, + }, + moscube=moscube, + search_filter=search_filter, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched memories from {cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search graph: search graph time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": cube_id, "memories": memories} + return None + + def search_preference_memory(cube_id, cube): + if ( + (cube_id in install_cube_ids) + and (cube.pref_mem is not None) + and self.config.enable_preference_memory + ): + time_start = time.time() + memories = cube.pref_mem.search( + query, + top_k=top_k if top_k else self.config.top_k, + info={ + "user_id": target_user_id, + "session_id": target_session_id, + "chat_history": chat_history.chat_history, + }, + ) + search_time_end = time.time() + logger.info( + f"🧠 [Memory] Searched preferences from {cube_id}:\n{self._str_memories(memories)}\n" + ) + logger.info( + f"time search pref: search pref time user_id: {target_user_id} time is: {search_time_end - time_start}" + ) + return {"cube_id": cube_id, "memories": memories} + return None + + # Execute both search functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(search_textual_memory, mem_cube_id, mem_cube) + pref_future = executor.submit(search_preference_memory, mem_cube_id, mem_cube) + + # Wait for both tasks to complete and collect results + text_result = text_future.result() + pref_result = pref_future.result() + + # Add results to the main result dictionary + if text_result is not None: + result["text_mem"].append(text_result) + if pref_result is not None: + result["pref_mem"].append(pref_result) + return result def add( @@ -679,79 +727,111 @@ def add( f"time add: get mem_cube_id time user_id: {target_user_id} time is: {time.time() - time_start}" ) - time_start_0 = time.time() if mem_cube_id not in self.mem_cubes: raise ValueError(f"MemCube '{mem_cube_id}' is not loaded. Please register.") - logger.info( - f"time add: get mem_cube_id check in mem_cubes time user_id: {target_user_id} time is: {time.time() - time_start_0}" - ) + + sync_mode = self.mem_cubes[mem_cube_id].text_mem.mode if sync_mode == "async": assert self.mem_scheduler is not None, ( "Mem-Scheduler must be working when use asynchronous memory adding." 
) logger.debug(f"Mem-reader mode is: {sync_mode}") - time_start_1 = time.time() - if ( - (messages is not None) - and self.config.enable_textual_memory - and self.mem_cubes[mem_cube_id].text_mem - ): - logger.info( - f"time add: messages is not None and enable_textual_memory and text_mem is not None time user_id: {target_user_id} time is: {time.time() - time_start_1}" - ) - if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": - add_memory = [] - metadata = TextualMemoryMetadata( - user_id=target_user_id, session_id=target_session_id, source="conversation" - ) - for message in messages: - add_memory.append( - TextualMemoryItem(memory=message["content"], metadata=metadata) + def process_textual_memory(): + if ( + (messages is not None) + and self.config.enable_textual_memory + and self.mem_cubes[mem_cube_id].text_mem + ): + if self.mem_cubes[mem_cube_id].config.text_mem.backend != "tree_text": + add_memory = [] + metadata = TextualMemoryMetadata( + user_id=target_user_id, session_id=target_session_id, source="conversation" ) - self.mem_cubes[mem_cube_id].text_mem.add(add_memory) - else: - messages_list = [messages] - time_start_2 = time.time() - memories = self.mem_reader.get_memory( - messages_list, - type="chat", - info={"user_id": target_user_id, "session_id": target_session_id}, - mode="fast" if sync_mode == "async" else "fine", - ) - logger.info( - f"time add: get mem_reader time user_id: {target_user_id} time is: {time.time() - time_start_2}" - ) - memories_flatten = [m for m_list in memories for m in m_list] - mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) - logger.info( - f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" - ) - # submit messages for scheduler - if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] - if sync_mode == "async": + for message in messages: + add_memory.append( + TextualMemoryItem(memory=message["content"], metadata=metadata) + ) + self.mem_cubes[mem_cube_id].text_mem.add(add_memory) + else: + messages_list = [messages] + memories = self.mem_reader.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": target_session_id}, + mode="fast" if sync_mode == "async" else "fine", + ) + memories_flatten = [m for m_list in memories for m in m_list] + mem_ids: list[str] = self.mem_cubes[mem_cube_id].text_mem.add(memories_flatten) + logger.info( + f"Added memory user {target_user_id} to memcube {mem_cube_id}: {mem_ids}" + ) + # submit messages for scheduler + if self.enable_mem_scheduler and self.mem_scheduler is not None: + mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "async": + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - label=MEM_READ_LABEL, + label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) + def process_preference_memory(): + if ( + (messages is not None) + and self.config.enable_preference_memory + and self.mem_cubes[mem_cube_id].pref_mem + ): + messages_list = [messages] + mem_cube = self.mem_cubes[mem_cube_id] + if sync_mode == "sync": + pref_memories = 
self.mem_cubes[mem_cube_id].pref_mem.get_memory( + messages_list, + type="chat", + info={"user_id": target_user_id, "session_id": target_session_id}, + ) + pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) + logger.info( + f"Added preferences user {target_user_id} to memcube {mem_cube_id}: {pref_ids}" + ) + elif sync_mode == "async": + assert self.mem_scheduler is not None, ( + "Mem-Scheduler must be working when using asynchronous memory adding." + ) message_item = ScheduleMessageItem( user_id=target_user_id, + session_id=target_session_id, mem_cube_id=mem_cube_id, mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) + # Execute both memory processing functions in parallel + with ThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(process_textual_memory) + pref_future = executor.submit(process_preference_memory) + + # Wait for both tasks to complete + text_future.result() + pref_future.result() + # user profile if ( (memory_content is not None) @@ -1030,7 +1110,7 @@ def load( load_dir: str, user_id: str | None = None, mem_cube_id: str | None = None, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Dump the MemCube to a dictionary. Args: diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 7e0ed9aef..fed8f7278 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -1443,6 +1443,24 @@ def search( reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) logger.info(f"search memory list is : {reformat_memory_list}") search_result["text_mem"] = reformat_memory_list + + pref_memory_list = search_result["pref_mem"] + reformat_pref_memory_list = [] + for memory in pref_memory_list: + memories_list = [] + for data in memory["memories"]: + memories = data.model_dump() + memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["embedding"] = [] + memories["metadata"]["sources"] = [] + memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" + memories["metadata"]["id"] = memories["id"] + memories["metadata"]["memory"] = memories["memory"] + memories_list.append(memories) + reformat_pref_memory_list.append( + {"cube_id": memory["cube_id"], "memories": memories_list} + ) + search_result["pref_mem"] = reformat_pref_memory_list time_end = time.time() logger.info( f"time search: total time for user_id: {user_id} time is: {time_end - time_start}" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 31bb9b3da..d84ebb242 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -12,6 +12,7 @@ DEFAULT_MAX_QUERY_KEY_WORDS, MEM_ORGANIZE_LABEL, MEM_READ_LABEL, + PREF_ADD_LABEL, QUERY_LABEL, WORKING_MEMORY_TYPE, MemCubeID, @@ -20,7 +21,9 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english -from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.preference import 
PreferenceTextMemory +from memos.memories.textual.tree import TreeTextMemory logger = get_logger(__name__) @@ -40,6 +43,7 @@ def __init__(self, config: GeneralSchedulerConfig): ADD_LABEL: self._add_message_consumer, MEM_READ_LABEL: self._mem_read_message_consumer, MEM_ORGANIZE_LABEL: self._mem_reorganize_message_consumer, + PREF_ADD_LABEL: self._pref_add_message_consumer, } self.dispatcher.register_handlers(handlers) @@ -468,6 +472,48 @@ def _process_memories_with_reorganize( f"Error in _process_memories_with_reader: {traceback.format_exc()}", exc_info=True ) + def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: + logger.info(f"Messages {messages} assigned to {PREF_ADD_LABEL} handler.") + + def process_message(message: ScheduleMessageItem): + try: + user_id = message.user_id + session_id = message.session_id + mem_cube_id = message.mem_cube_id + mem_cube = message.mem_cube + content = message.content + messages_list = json.loads(content) + + logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") + + # Get the preference memory from the mem_cube + pref_mem = mem_cube.pref_mem + if not isinstance(pref_mem, PreferenceTextMemory): + logger.error(f"Expected PreferenceTextMemory but got {type(pref_mem).__name__}") + return + + # Use pref_mem.get_memory to process the memories + pref_memories = pref_mem.get_memory( + messages_list, type="chat", info={"user_id": user_id, "session_id": session_id} + ) + # Add pref_mem to vector db + pref_ids = pref_mem.add(pref_memories) + + logger.info( + f"Successfully processed and added preferences for user_id={user_id}, mem_cube_id={mem_cube_id}, pref_ids={pref_ids}" + ) + + except Exception as e: + logger.error(f"Error processing pref_add message: {e}", exc_info=True) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, min(8, len(messages)))) as executor: + futures = [executor.submit(process_message, msg) for msg in messages] + for future in concurrent.futures.as_completed(futures): + try: + future.result() + except Exception as e: + logger.error(f"Thread task failed: {e}", exc_info=True) + def process_session_turn( self, queries: str | list[str], diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index f0868e8df..2bc7a3b98 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -20,7 +20,7 @@ class SearchMode(str, Enum): MEM_READ_LABEL = "mem_read" MEM_ORGANIZE_LABEL = "mem_organize" API_MIX_SEARCH_LABEL = "api_mix_search" - +PREF_ADD_LABEL = "pref_add" TreeTextMemory_SEARCH_METHOD = "tree_text_memory_search" TreeTextMemory_FINE_SEARCH_METHOD = "tree_text_memory_fine_search" diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index efdaa44ef..541d2486d 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -35,6 +35,8 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) user_id: str = Field(..., description="user id")
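+    # Session that produced this message; async consumers (e.g. the pref_add handler) use it to scope processing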
+    session_id: str | None = Field(default=None, description="session id") mem_cube_id: str = Field(..., description="memcube id") label: str = Field(..., description="Label of the schedule message") mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule") @@ -55,6 +57,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "example": { "item_id": "123e4567-e89b-12d3-a456-426614174000", # Sample UUID "user_id": "user123", # Example user identifier + "session_id": "session123", # Example session identifier "mem_cube_id": "cube456", # Sample memory cube ID "label": "sample_label", # Demonstration label value "mem_cube": "obj of GeneralMemCube", # Added mem_cube example @@ -76,6 +79,7 @@ def to_dict(self) -> dict: return { "item_id": self.item_id, "user_id": self.user_id, + "session_id": self.session_id, "cube_id": self.mem_cube_id, "label": self.label, "cube": "Not Applicable", # Custom cube serialization @@ -90,6 +94,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": item_id=data.get("item_id", str(uuid4())), user_id=data["user_id"], mem_cube_id=data["cube_id"], + session_id=data.get("session_id"), label=data["label"], mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index bcf7fdd9b..5ba1c6726 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,6 +10,8 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.preference import PreferenceTextMemory +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -22,6 +24,8 @@ class MemoryFactory(BaseMemory): "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, "simple_tree_text": SimpleTreeTextMemory, + "pref_text": PreferenceTextMemory, + "simple_pref_text": SimplePreferenceTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 2da283d47..6d975cfd7 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -167,6 +167,20 @@ class SearchedTreeNodeTextualMemoryMetadata(TreeNodeTextualMemoryMetadata): ) +class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): + """Metadata for preference memory item.""" + + preference_type: Literal["explicit_preference", "implicit_preference"] = Field( + default="explicit_preference", description="Type of preference." + ) + dialog_id: str | None = Field(default=None, description="ID of the dialog.") + dialog_str: str | None = Field(default=None, description="String of the dialog.") + embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") + explicit_preference: str | None = Field(default=None, description="Explicit preference.") + created_at: str | None = Field(default=None, description="Timestamp of the dialog.") + implicit_preference: str | None = Field(default=None, description="Implicit preference.") + + class TextualMemoryItem(BaseModel): """Represents a single memory item in the textual memory. 
@@ -180,6 +194,7 @@ class TextualMemoryItem(BaseModel): SearchedTreeNodeTextualMemoryMetadata | TreeNodeTextualMemoryMetadata | TextualMemoryMetadata + | PreferenceTextualMemoryMetadata ) = Field(default_factory=TextualMemoryMetadata) model_config = ConfigDict(extra="forbid") @@ -204,12 +219,15 @@ def _coerce_metadata(cls, v: Any): v, SearchedTreeNodeTextualMemoryMetadata | TreeNodeTextualMemoryMetadata - | TextualMemoryMetadata, + | TextualMemoryMetadata + | PreferenceTextualMemoryMetadata, ): return v if isinstance(v, dict): if v.get("relativity") is not None: return SearchedTreeNodeTextualMemoryMetadata(**v) + if v.get("preference_type") is not None: + return PreferenceTextualMemoryMetadata(**v) if any(k in v for k in ("sources", "memory_type", "embedding", "background", "usage")): return TreeNodeTextualMemoryMetadata(**v) return TextualMemoryMetadata(**v) diff --git a/src/memos/memories/textual/prefer_text_memory/__init__.py b/src/memos/memories/textual/prefer_text_memory/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py new file mode 100644 index 000000000..390f048ef --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -0,0 +1,284 @@ +import json + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryItem +from memos.templates.prefer_complete_prompt import ( + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE, +) +from memos.vec_dbs.item import MilvusVecDBItem + + +logger = get_logger(__name__) + + +class BaseAdder(ABC): + """Abstract base class for adders.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the adder.""" + + @abstractmethod + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwargs) -> list[str]: + """Add the instruct preference memories. + Args: + memories (list[TextualMemoryItem | dict[str, Any]]): The memories to add. + **kwargs: Additional keyword arguments. + Returns: + list[str]: List of added memory IDs. 
+ """ + + +class NaiveAdder(BaseAdder): + """Naive adder.""" + + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the naive adder.""" + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + + def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: + """Judge if the new message expresses the same core content as the old message.""" + # Use the template prompt with placeholders + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT.replace("{old_information}", old_msg).replace( + "{new_information}", new_msg + ) + + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + response = result.get("is_same", False) + return response if isinstance(response, bool) else response == "true" + except Exception as e: + logger.error(f"Error in judge_update_or_add: {e}") + # Fallback to simple string comparison + return old_msg == new_msg + + def _judge_update_or_add_trace_op( + self, new_mem: str, retrieved_mems: str + ) -> dict[str, Any] | None: + prompt = NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE.replace("{new_memory}", new_mem).replace( + "{retrieved_memories}", retrieved_mems + ) + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + return result + except Exception as e: + logger.error(f"Error in judge_update_or_add_trace_op: {e}") + return None + + def _update_memory_op_trace( + self, + new_memory: TextualMemoryItem, + retrieved_memories: list[MilvusVecDBItem], + collection_name: str, + preference_type: str, + ) -> list[str] | str: + if not retrieved_memories: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + new_mem_input = { + "context_summary": new_memory.memory, + "preference": new_memory.metadata.explicit_preference + if preference_type == "explicit_preference" + else new_memory.metadata.implicit_preference, + } + retrieved_mem_inputs = [ + { + "id": mem.id, + "context_summary": mem.memory, + "preference": mem.payload[preference_type], + } + for mem in retrieved_memories + ] + + rsp = self._judge_update_or_add_trace_op( + new_mem=json.dumps(new_mem_input), retrieved_mems=json.dumps(retrieved_mem_inputs) + ) + if not rsp: + payload = new_memory.to_dict()["metadata"] + fields_to_remove = {"dialog_id", "dialog_str", "embedding"} + payload = {k: v for k, v in payload.items() if k not in fields_to_remove} + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + vector=new_memory.metadata.embedding, + payload=payload, + ) + self.vector_db.add(collection_name, [vec_db_item]) + return new_memory.id + + def execute_op(op): + op_type = op["type"].lower() + if op_type == "add": + payload = new_memory.to_dict()["metadata"] + payload = { + k: v + for k, v in payload.items() + if k not in {"dialog_id", "dialog_str", "embedding"} + } + vec_db_item = MilvusVecDBItem( + id=new_memory.id, + memory=new_memory.memory, + 
vector=new_memory.metadata.embedding,
+                    payload=payload,
+                )
+                self.vector_db.add(collection_name, [vec_db_item])
+                return new_memory.id
+            elif op_type == "update":
+                payload = {
+                    "preference_type": preference_type,
+                    preference_type: op["new_preference"],
+                }
+                vec_db_item = MilvusVecDBItem(
+                    id=op["target_id"],
+                    memory=op["new_context_summary"],
+                    vector=self.embedder.embed([op["new_context_summary"]])[0],
+                    payload=payload,
+                )
+                self.vector_db.update(collection_name, op["target_id"], vec_db_item)
+                return op["target_id"]
+            elif op_type == "delete":
+                self.vector_db.delete(collection_name, [op["target_id"]])
+                return None
+
+        # Guard against an empty or missing trace: ThreadPoolExecutor raises
+        # ValueError when max_workers is 0.
+        if not rsp.get("trace"):
+            return []
+
+        with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor:
+            future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]}
+            added_ids = []
+            for future in as_completed(future_to_op):
+                result = future.result()
+                if result is not None:
+                    added_ids.append(result)
+
+        return added_ids
+
+    def _update_memory_fast(
+        self,
+        new_memory: TextualMemoryItem,
+        retrieved_memories: list[MilvusVecDBItem],
+        collection_name: str,
+    ) -> str:
+        payload = new_memory.to_dict()["metadata"]
+        fields_to_remove = {"dialog_id", "dialog_str", "embedding"}
+        payload = {k: v for k, v in payload.items() if k not in fields_to_remove}
+        vec_db_item = MilvusVecDBItem(
+            id=new_memory.id,
+            memory=new_memory.memory,
+            vector=new_memory.metadata.embedding,
+            payload=payload,
+        )
+        recall = retrieved_memories[0] if retrieved_memories else None
+        if not recall or (recall.score is not None and recall.score < 0.5):
+            self.vector_db.add(collection_name, [vec_db_item])
+            return new_memory.id
+
+        old_msg_str = recall.memory
+        new_msg_str = new_memory.memory
+        is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str)
+        if is_same:
+            # Same core preference: replace the old entry with the refreshed one.
+            self.vector_db.delete(collection_name, [recall.id])
+            self.vector_db.update(collection_name, new_memory.id, vec_db_item)
+        else:
+            # Distinct preference despite high similarity: keep the old entry
+            # and add the new one (previously the new memory was silently
+            # dropped while its ID was still returned).
+            self.vector_db.add(collection_name, [vec_db_item])
+        return new_memory.id
+
+    def _update_memory(
+        self,
+        new_memory: TextualMemoryItem,
+        retrieved_memories: list[MilvusVecDBItem],
+        collection_name: str,
+        preference_type: str,
+        update_mode: str = "op_trace",
+    ) -> list[str] | str | None:
+        """Update the memory.
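+
+        "op_trace" asks the LLM for an explicit ADD/UPDATE/DELETE operation
+        trace and applies each operation; "fast" only compares against the
+        single nearest neighbour, replacing it when the LLM judges the two the
+        same and adding the new memory otherwise.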
+ Args: + new_memory: TextualMemoryItem + retrieved_memories: list[MilvusVecDBItem] + collection_name: str + preference_type: str + update_mode: str, "op_trace" or "fast" + """ + if update_mode == "op_trace": + return self._update_memory_op_trace( + new_memory, retrieved_memories, collection_name, preference_type + ) + elif update_mode == "fast": + return self._update_memory_fast(new_memory, retrieved_memories, collection_name) + else: + raise ValueError(f"Invalid update mode: {update_mode}") + + def _process_single_memory(self, memory: TextualMemoryItem) -> list[str] | str | None: + """Process a single memory and return its ID if added successfully.""" + try: + pref_type_collection_map = { + "explicit_preference": "explicit_preference", + "implicit_preference": "implicit_preference", + } + preference_type = memory.metadata.preference_type + collection_name = pref_type_collection_map[preference_type] + + search_results = self.vector_db.search( + memory.metadata.embedding, + collection_name, + top_k=5, + filter={"user_id": memory.metadata.user_id}, + ) + search_results.sort(key=lambda x: x.score, reverse=True) + + return self._update_memory( + memory, search_results, collection_name, preference_type, update_mode="fast" + ) + + except Exception as e: + logger.error(f"Error processing memory {memory.id}: {e}") + return None + + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + max_workers: int = 8, + *args, + **kwargs, + ) -> list[str]: + """Add the instruct preference memories using thread pool for acceleration.""" + if not memories: + return [] + + added_ids = [] + with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: + future_to_memory = { + executor.submit(self._process_single_memory, memory): memory for memory in memories + } + + for future in as_completed(future_to_memory): + try: + memory_id = future.result() + if memory_id: + if isinstance(memory_id, list): + added_ids.extend(memory_id) + else: + added_ids.append(memory_id) + except Exception as e: + memory = future_to_memory[future] + logger.error(f"Error processing memory {memory.id}: {e}") + continue + + return added_ids diff --git a/src/memos/memories/textual/prefer_text_memory/config.py b/src/memos/memories/textual/prefer_text_memory/config.py new file mode 100644 index 000000000..7e8354747 --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/config.py @@ -0,0 +1,106 @@ +from typing import Any, ClassVar + +from pydantic import Field, field_validator, model_validator + +from memos.configs.base import BaseConfig + + +class BaseAdderConfig(BaseConfig): + """Base configuration class for Adder.""" + + +class NaiveAdderConfig(BaseAdderConfig): + """Configuration for Naive Adder.""" + + # No additional config needed since components are passed from parent + + +class AdderConfigFactory(BaseConfig): + """Factory class for creating Adder configurations.""" + + backend: str = Field(..., description="Backend for Adder") + config: dict[str, Any] = Field(..., description="Configuration for the Adder backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveAdderConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "AdderConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = 
config_class(**self.config) + return self + + +class BaseExtractorConfig(BaseConfig): + """Base configuration class for Extractor.""" + + +class NaiveExtractorConfig(BaseExtractorConfig): + """Configuration for Naive Extractor.""" + + +class ExtractorConfigFactory(BaseConfig): + """Factory class for creating Extractor configurations.""" + + backend: str = Field(..., description="Backend for Extractor") + config: dict[str, Any] = Field(..., description="Configuration for the Extractor backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveExtractorConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "ExtractorConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self + + +class BaseRetrieverConfig(BaseConfig): + """Base configuration class for Retrievers.""" + + +class NaiveRetrieverConfig(BaseRetrieverConfig): + """Configuration for Naive Retriever.""" + + +class RetrieverConfigFactory(BaseConfig): + """Factory class for creating Retriever configurations.""" + + backend: str = Field(..., description="Backend for Retriever") + config: dict[str, Any] = Field(..., description="Configuration for the Retriever backend") + + backend_to_class: ClassVar[dict[str, Any]] = { + "naive": NaiveRetrieverConfig, + } + + @field_validator("backend") + @classmethod + def validate_backend(cls, backend: str) -> str: + """Validate the backend field.""" + if backend not in cls.backend_to_class: + raise ValueError(f"Invalid backend: {backend}") + return backend + + @model_validator(mode="after") + def create_config(self) -> "RetrieverConfigFactory": + config_class = self.backend_to_class[self.backend] + self.config = config_class(**self.config) + return self diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py new file mode 100644 index 000000000..460b31f4f --- /dev/null +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -0,0 +1,184 @@ +import json +import uuid + +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from typing import Any + +from memos.log import get_logger +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.prefer_text_memory.spliter import Splitter +from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string +from memos.templates.prefer_complete_prompt import ( + NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT, + NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT, +) +from memos.types import MessageList + + +logger = get_logger(__name__) + + +class BaseExtractor(ABC): + """Abstract base class for extractors.""" + + @abstractmethod + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the extractor.""" + + +class NaiveExtractor(BaseExtractor): + """Extractor.""" + + def __init__(self, llm_provider=None, embedder=None, vector_db=None): + """Initialize the extractor.""" + super().__init__(llm_provider, embedder, vector_db) + self.llm_provider = llm_provider + self.embedder = embedder + self.vector_db = vector_db + self.splitter = Splitter() + + def 
extract_basic_info(self, qa_pair: MessageList) -> dict[str, Any]:
+        """Extract basic information from a QA pair (no LLM needed)."""
+        basic_info = {
+            "dialog_id": str(uuid.uuid4()),
+            "dialog_str": convert_messages_to_string(qa_pair),
+            "created_at": datetime.now().isoformat(),
+        }
+
+        return basic_info
+
+    def extract_explicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None:
+        """Extract explicit preferences from a QA pair."""
+        qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair
+        prompt = NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str)
+
+        try:
+            response = self.llm_provider.generate([{"role": "user", "content": prompt}])
+            response = response.strip().replace("```json", "").replace("```", "").strip()
+            result = json.loads(response)
+            return result
+        except Exception as e:
+            logger.error(f"Error extracting explicit preference: {e}; returning None")
+            return None
+
+    def extract_implicit_preference(self, qa_pair: MessageList | str) -> dict[str, Any] | None:
+        """Extract an implicit preference from a chunk of QA pairs."""
+        if not qa_pair:
+            return None
+        qa_pair_str = convert_messages_to_string(qa_pair) if isinstance(qa_pair, list) else qa_pair
+        prompt = NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT.replace("{qa_pair}", qa_pair_str)
+
+        try:
+            response = self.llm_provider.generate([{"role": "user", "content": prompt}])
+            response = response.strip().replace("```json", "").replace("```", "").strip()
+            result = json.loads(response)
+            return result
+        except Exception as e:
+            logger.error(f"Error extracting implicit preference: {e}; returning None")
+            return None
+
+    def _process_single_chunk_explicit(
+        self, chunk: MessageList, msg_type: str, info: dict[str, Any]
+    ) -> list[TextualMemoryItem] | None:
+        """Process a single chunk and return a list of TextualMemoryItems,
+        one per extracted explicit preference."""
+        basic_info = self.extract_basic_info(chunk)
+        if not basic_info["dialog_str"]:
+            return None
+
+        explicit_pref = self.extract_explicit_preference(basic_info["dialog_str"])
+        if not explicit_pref:
+            return None
+
+        memories = []
+        for pref in explicit_pref:
+            vector_info = {
+                "embedding": self.embedder.embed([pref["context_summary"]])[0],
+            }
+            extract_info = {**basic_info, **pref, **vector_info, **info}
+
+            metadata = PreferenceTextualMemoryMetadata(
+                type=msg_type, preference_type="explicit_preference", **extract_info
+            )
+            memory = TextualMemoryItem(
+                id=str(uuid.uuid4()), memory=pref["context_summary"], metadata=metadata
+            )
+
+            memories.append(memory)
+
+        return memories
+
+    def _process_single_chunk_implicit(
+        self, chunk: MessageList, msg_type: str, info: dict[str, Any]
+    ) -> TextualMemoryItem | None:
+        """Process a single chunk and return one implicit-preference memory, if any."""
+        basic_info = self.extract_basic_info(chunk)
+        if not basic_info["dialog_str"]:
+            return None
+        implicit_pref = self.extract_implicit_preference(basic_info["dialog_str"])
+        if not implicit_pref:
+            return None
+
+        vector_info = {
+            "embedding": self.embedder.embed([implicit_pref["context_summary"]])[0],
+        }
+
+        extract_info = {**basic_info, **implicit_pref, **vector_info, **info}
+
+        metadata = PreferenceTextualMemoryMetadata(
+            type=msg_type, preference_type="implicit_preference", **extract_info
+        )
+        memory = TextualMemoryItem(
+            id=extract_info["dialog_id"], memory=implicit_pref["context_summary"], metadata=metadata
+        )
+
+        return memory
+
+    def extract(
+        self,
+        messages: list[MessageList],
+        msg_type: str,
+        info: dict[str, Any],
+        max_workers: int = 10,
+    ) -> list[TextualMemoryItem]:
+        """Extract preference memories based on the messages using 
thread pool for acceleration."""
+        chunks: list[MessageList] = []
+        for message in messages:
+            chunk = self.splitter.split_chunks(message, split_type="overlap")
+            chunks.extend(chunk)
+        if not chunks:
+            return []
+
+        memories = []
+        with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor:
+            futures = {
+                executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): (
+                    "explicit",
+                    chunk,
+                )
+                for chunk in chunks
+            }
+            futures.update(
+                {
+                    executor.submit(self._process_single_chunk_implicit, chunk, msg_type, info): (
+                        "implicit",
+                        chunk,
+                    )
+                    for chunk in chunks
+                }
+            )
+
+            for future in as_completed(futures):
+                try:
+                    memory = future.result()
+                    if memory:
+                        if isinstance(memory, list):
+                            memories.extend(memory)
+                        else:
+                            memories.append(memory)
+                except Exception as e:
+                    task_type, chunk = futures[future]
+                    logger.error(f"Error processing {task_type} chunk: {chunk}\n{e}")
+                    continue
+
+        return memories
diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py
new file mode 100644
index 000000000..22182261a
--- /dev/null
+++ b/src/memos/memories/textual/prefer_text_memory/factory.py
@@ -0,0 +1,78 @@
+from typing import Any, ClassVar
+
+from memos.memories.textual.prefer_text_memory.adder import BaseAdder, NaiveAdder
+from memos.memories.textual.prefer_text_memory.config import (
+    AdderConfigFactory,
+    ExtractorConfigFactory,
+    RetrieverConfigFactory,
+)
+from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor, NaiveExtractor
+from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever, NaiveRetriever
+
+
+class AdderFactory(BaseAdder):
+    """Factory class for creating Adder instances."""
+
+    backend_to_class: ClassVar[dict[str, Any]] = {
+        "naive": NaiveAdder,
+    }
+
+    @classmethod
+    def from_config(
+        cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None
+    ) -> BaseAdder:
+        """Create an Adder instance from a configuration factory."""
+        backend = config_factory.backend
+        if backend not in cls.backend_to_class:
+            raise ValueError(f"Invalid backend: {backend}")
+        adder_class = cls.backend_to_class[backend]
+        return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db)
+
+
+class ExtractorFactory(BaseExtractor):
+    """Factory class for creating Extractor instances."""
+
+    backend_to_class: ClassVar[dict[str, Any]] = {
+        "naive": NaiveExtractor,
+    }
+
+    @classmethod
+    def from_config(
+        cls,
+        config_factory: ExtractorConfigFactory,
+        llm_provider=None,
+        embedder=None,
+        vector_db=None,
+    ) -> BaseExtractor:
+        """Create an Extractor instance from a configuration factory."""
+        backend = config_factory.backend
+        if backend not in cls.backend_to_class:
+            raise ValueError(f"Invalid backend: {backend}")
+        extractor_class = cls.backend_to_class[backend]
+        return extractor_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db)
+
+
+class RetrieverFactory(BaseRetriever):
+    """Factory class for creating Retriever instances."""
+
+    backend_to_class: ClassVar[dict[str, Any]] = {
+        "naive": NaiveRetriever,
+    }
+
+    @classmethod
+    def from_config(
+        cls,
+        config_factory: RetrieverConfigFactory,
+        llm_provider=None,
+        embedder=None,
+        reranker=None,
+        vector_db=None,
+    ) -> BaseRetriever:
+        """Create a Retriever instance from a configuration factory."""
+        backend = config_factory.backend
+        if backend not in cls.backend_to_class:
+            raise ValueError(f"Invalid backend: {backend}")
+        # The naive retriever ignores llm_provider; it is accepted so the
+        # factory signatures stay uniform across adder/extractor/retriever.
+        
retriever_class = cls.backend_to_class[backend]
+        return retriever_class(
+            llm_provider=llm_provider, embedder=embedder, reranker=reranker, vector_db=vector_db
+        )
diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py
new file mode 100644
index 000000000..7f70bac3b
--- /dev/null
+++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py
@@ -0,0 +1,88 @@
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+
+from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem
+
+
+class BaseRetriever(ABC):
+    """Abstract base class for retrievers."""
+
+    @abstractmethod
+    def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None):
+        """Initialize the retriever."""
+
+    @abstractmethod
+    def retrieve(
+        self, query: str, top_k: int, info: dict[str, Any] | None = None
+    ) -> list[TextualMemoryItem]:
+        """Retrieve memories from the retriever."""
+
+
+class NaiveRetriever(BaseRetriever):
+    """Naive retriever."""
+
+    def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=None):
+        """Initialize the naive retriever."""
+        super().__init__(llm_provider, embedder, reranker, vector_db)
+        self.reranker = reranker
+        self.vector_db = vector_db
+        self.embedder = embedder
+
+    def retrieve(
+        self, query: str, top_k: int, info: dict[str, Any] | None = None
+    ) -> list[TextualMemoryItem]:
+        """Retrieve memories from the naive retriever."""
+        # TODO: query rewriting and session-level filtering are not supported yet
+        if info:
+            info = info.copy()  # Create a copy to avoid modifying the original
+            info.pop("chat_history", None)
+            info.pop("session_id", None)
+        query_embeddings = self.embedder.embed([query])  # Pass as list to get list of embeddings
+        query_embedding = query_embeddings[0]  # Get the first (and only) embedding
+
+        # Use thread pool to parallelize the searches
+        with ThreadPoolExecutor(max_workers=2) as executor:
+            # Submit all search tasks
+            future_explicit = executor.submit(
+                self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info
+            )
+            future_implicit = executor.submit(
+                self.vector_db.search, query_embedding, "implicit_preference", top_k * 2, info
+            )
+
+            # Wait for all results
+            explicit_prefs = future_explicit.result()
+            implicit_prefs = future_implicit.result()
+
+        # sort by score
+        explicit_prefs.sort(key=lambda x: x.score, reverse=True)
+        implicit_prefs.sort(key=lambda x: x.score, reverse=True)
+
+        explicit_prefs = [
+            TextualMemoryItem(
+                id=pref.id,
+                memory=pref.memory,
+                metadata=PreferenceTextualMemoryMetadata(**pref.payload),
+            )
+            for pref in explicit_prefs
+            if pref.payload.get("explicit_preference")  # .get(): payloads may omit the key
+        ]
+
+        implicit_prefs = [
+            TextualMemoryItem(
+                id=pref.id,
+                memory=pref.memory,
+                metadata=PreferenceTextualMemoryMetadata(**pref.payload),
+            )
+            for pref in implicit_prefs
+            if pref.payload.get("implicit_preference")  # .get(): payloads may omit the key
+        ]
+
+        if self.reranker:
+            explicit_prefs = self.reranker.rerank(query, explicit_prefs, top_k)
+            implicit_prefs = self.reranker.rerank(query, implicit_prefs, top_k)
+            explicit_prefs = [item for item, _ in explicit_prefs]
+            implicit_prefs = [item for item, _ in implicit_prefs]
+
+        return explicit_prefs + implicit_prefs
diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py
new file mode 100644
index 000000000..59a6b0052
--- /dev/null
+++ 
b/src/memos/memories/textual/prefer_text_memory/spliter.py
@@ -0,0 +1,132 @@
+import copy
+
+from memos.chunkers import ChunkerFactory
+from memos.configs.chunker import ChunkerConfigFactory
+from memos.configs.parser import ParserConfigFactory
+from memos.parsers.factory import ParserFactory
+from memos.types import MessageList
+
+
+class Splitter:
+    """Splitter."""
+
+    def __init__(
+        self,
+        lookback_turns: int = 1,
+        chunk_size: int = 256,
+        chunk_overlap: int = 128,
+        min_sentences_per_chunk: int = 1,
+        tokenizer: str = "gpt2",
+        parser_backend: str = "markitdown",
+        chunker_backend: str = "sentence",
+    ):
+        """Initialize the splitter."""
+        self.lookback_turns = lookback_turns
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.min_sentences_per_chunk = min_sentences_per_chunk
+        self.tokenizer = tokenizer
+        self.chunker_backend = chunker_backend
+        self.parser_backend = parser_backend
+        # Initialize parser
+        parser_config = ParserConfigFactory.model_validate(
+            {
+                "backend": self.parser_backend,
+                "config": {},
+            }
+        )
+        self.parser = ParserFactory.from_config(parser_config)
+
+        # Initialize chunker
+        chunker_config = ChunkerConfigFactory.model_validate(
+            {
+                "backend": self.chunker_backend,
+                "config": {
+                    "tokenizer_or_token_counter": self.tokenizer,
+                    "chunk_size": self.chunk_size,
+                    "chunk_overlap": self.chunk_overlap,
+                    "min_sentences_per_chunk": self.min_sentences_per_chunk,
+                },
+            }
+        )
+        self.chunker = ChunkerFactory.from_config(chunker_config)
+
+    def _split_with_lookback(self, data: MessageList) -> list[MessageList]:
+        """Split the messages into chunks by looking back a fixed number of turns.
+        Adjacent chunks have a high duplicate rate; with the default
+        lookback_turns of 1, each chunk contains only the current turn."""
+        # Build QA pairs from chat history
+        pairs = self.build_qa_pairs(data)
+        chunks = []
+
+        # Create chunks by looking back a fixed number of turns
+        for i in range(len(pairs)):
+            # Calculate the start index for lookback
+            start_idx = max(0, i + 1 - self.lookback_turns)
+            # Get the chunk of pairs (as many as available, up to lookback_turns)
+            chunk_pairs = pairs[start_idx : i + 1]
+
+            # Flatten chunk_pairs (list[list[dict]]) to MessageList (list[dict])
+            chunk_messages = []
+            for pair in chunk_pairs:
+                chunk_messages.extend(pair)
+
+            chunks.append(chunk_messages)
+        return chunks
+
+    def _split_with_overlap(self, data: MessageList) -> list[MessageList]:
+        """Split the messages into chunks with overlap.
+        Adjacent chunks have a low duplicate rate."""
+        chunks = []
+        chunk = []
+        for item in data:
+            chunk.append(item)
+            # 5 turns (Q + A = 10 messages) per chunk
+            if len(chunk) >= 10:
+                chunks.append(chunk)
+                # overlap of 1 turn (Q + A = 2 messages)
+                context = copy.deepcopy(chunk[-2:])
+                chunk = context
+        if chunk:
+            chunks.append(chunk)
+
+        return chunks
+
+    def split_chunks(self, data: MessageList | str, **kwargs) -> list[MessageList] | list[str]:
+        """Split the messages or files into chunks.
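+        Message lists are split turn-wise ("lookback" or "overlap", chosen via
+        the split_type kwarg); strings are parsed and chunked by the configured
+        chunker. Illustrative example with 12 alternating user/assistant
+        messages:
+            >>> [len(c) for c in splitter.split_chunks(msgs, split_type="overlap")]
+            [10, 4]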
+
+        Args:
+            data: MessageList or string to split
+            split_type: "lookback" or "overlap" (required for message lists)
+
+        Returns:
+            List of MessageList chunks or list of string chunks
+        """
+        if isinstance(data, list):
+            if kwargs.get("split_type") == "lookback":
+                chunks = self._split_with_lookback(data)
+            elif kwargs.get("split_type") == "overlap":
+                chunks = self._split_with_overlap(data)
+            else:
+                # Guard against an unset or unknown split_type, which would
+                # otherwise leave `chunks` unbound and raise a NameError.
+                raise ValueError(f"Invalid split_type: {kwargs.get('split_type')}")
+            return chunks
+        else:
+            # Parse and chunk the string data using pre-initialized components
+            text = self.parser.parse(data)
+            chunks = self.chunker.chunk(text)
+
+            return [chunk.text for chunk in chunks]
+
+    def build_qa_pairs(self, chat_history: MessageList) -> list[MessageList]:
+        """Build QA pairs from chat history."""
+        qa_pairs = []
+        current_qa_pair = []
+
+        for message in chat_history:
+            if message["role"] == "user":
+                current_qa_pair.append(message)
+            elif message["role"] == "assistant":
+                if not current_qa_pair:
+                    continue
+                current_qa_pair.append(message)
+                qa_pairs.append(current_qa_pair.copy())
+                current_qa_pair = []  # reset
+
+        return qa_pairs
diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py
new file mode 100644
index 000000000..85adc9304
--- /dev/null
+++ b/src/memos/memories/textual/prefer_text_memory/utils.py
@@ -0,0 +1,70 @@
+import re
+
+from memos.dependency import require_python_package
+from memos.memories.textual.item import TextualMemoryItem
+from memos.types import MessageList
+
+
+def convert_messages_to_string(messages: MessageList) -> str:
+    """Convert a list of messages to a string."""
+    message_text = ""
+    for message in messages:
+        if message["role"] == "user":
+            message_text += f"Query: {message['content']}\n" if message["content"].strip() else ""
+        elif message["role"] == "assistant":
+            message_text += f"Answer: {message['content']}\n" if message["content"].strip() else ""
+    message_text = message_text.strip()
+    return message_text
+
+
+@require_python_package(
+    import_name="datasketch",
+    install_command="pip install datasketch",
+    install_link="https://github.com/ekzhu/datasketch",
+)
+def deduplicate_preferences(
+    prefs: list[TextualMemoryItem], similarity_threshold: float = 0.6, num_perm: int = 256
+) -> list[TextualMemoryItem]:
+    """
+    Deduplicate preference texts using MinHash algorithm.
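+
+    Near-duplicates are detected with MinHashLSH over lowercase word tokens;
+    only the first item of each similar group is kept, so input order matters.
+    Illustrative only:
+        >>> deduped = deduplicate_preferences(prefs)
+        >>> len(deduped) <= len(prefs)
+        True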
+
+    Args:
+        prefs: List of preference memory items to deduplicate
+        similarity_threshold: Jaccard similarity threshold (0.0-1.0), default 0.6
+        num_perm: Number of MinHash permutations, default 256
+
+    Returns:
+        Deduplicated list of preference items
+    """
+    from datasketch import MinHash, MinHashLSH
+
+    if not prefs:
+        return prefs
+
+    # Use MinHashLSH for efficient similarity search
+    lsh = MinHashLSH(threshold=similarity_threshold, num_perm=num_perm)
+    unique_prefs = []
+
+    for i, pref in enumerate(prefs):
+        # Extract preference text
+        if hasattr(pref.metadata, "implicit_preference") and pref.metadata.implicit_preference:
+            text = pref.metadata.implicit_preference
+        elif hasattr(pref.metadata, "explicit_preference") and pref.metadata.explicit_preference:
+            text = pref.metadata.explicit_preference
+        else:
+            text = pref.memory
+
+        # Create MinHash from text tokens
+        minhash = MinHash(num_perm=num_perm)
+        # Simple tokenization: split by whitespace and clean
+        tokens = re.findall(r"\w+", text.lower())
+        for token in tokens:
+            minhash.update(token.encode("utf8"))
+
+        # Check for duplicates using LSH
+        similar_items = lsh.query(minhash)
+
+        if not similar_items:  # No similar items found
+            lsh.insert(i, minhash)
+            unique_prefs.append(pref)
+
+    return unique_prefs
diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py
new file mode 100644
index 000000000..5f85aa907
--- /dev/null
+++ b/src/memos/memories/textual/preference.py
@@ -0,0 +1,283 @@
+import json
+import os
+
+from typing import Any
+
+from memos.configs.memory import PreferenceTextMemoryConfig
+from memos.embedders.factory import (
+    ArkEmbedder,
+    EmbedderFactory,
+    OllamaEmbedder,
+    SenTranEmbedder,
+    UniversalAPIEmbedder,
+)
+from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM
+from memos.log import get_logger
+from memos.memories.textual.base import BaseTextMemory
+from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem
+from memos.memories.textual.prefer_text_memory.factory import (
+    AdderFactory,
+    ExtractorFactory,
+    RetrieverFactory,
+)
+from memos.reranker.factory import RerankerFactory
+from memos.types import MessageList
+from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB, VecDBFactory
+from memos.vec_dbs.item import VecDBItem
+
+
+logger = get_logger(__name__)
+
+
+class PreferenceTextMemory(BaseTextMemory):
+    """Preference textual memory implementation for storing and retrieving memories."""
+
+    def __init__(self, config: PreferenceTextMemoryConfig):
+        """Initialize memory with the given configuration."""
+        self.config: PreferenceTextMemoryConfig = config
+        self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
+            config.extractor_llm
+        )
+        self.vector_db: MilvusVecDB | QdrantVecDB = VecDBFactory.from_config(config.vector_db)
+        self.embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder = (
+            EmbedderFactory.from_config(config.embedder)
+        )
+        self.reranker = RerankerFactory.from_config(config.reranker)
+
+        self.extractor = ExtractorFactory.from_config(
+            config.extractor,
+            llm_provider=self.extractor_llm,
+            embedder=self.embedder,
+            vector_db=self.vector_db,
+        )
+
+        self.adder = AdderFactory.from_config(
+            config.adder,
+            llm_provider=self.extractor_llm,
+            embedder=self.embedder,
+            vector_db=self.vector_db,
+        )
+        self.retriever = RetrieverFactory.from_config(
+            config.retriever,
+            llm_provider=self.extractor_llm,
+            embedder=self.embedder,
+            reranker=self.reranker,
+            vector_db=self.vector_db,
+        )
+
+    def get_memory(
+        self, 
messages: list[MessageList], type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: + """Get memory based on the messages. + Args: + messages (list[MessageList]): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. + """ + return self.extractor.extract(messages, type, info) + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + def load(self, dir: str) -> None: + """Load memories from the specified directory. + Args: + dir (str): The directory containing the memory files. + """ + # For preference memory, we don't need to load from files + # as the data is stored in the vector database + try: + memory_file = os.path.join(dir, self.config.memory_filename) + + if not os.path.exists(memory_file): + logger.warning(f"Memory file not found: {memory_file}") + return + + with open(memory_file, encoding="utf-8") as f: + memories = json.load(f) + for collection_name, items in memories.items(): + vec_db_items = [VecDBItem.from_dict(m) for m in items] + self.vector_db.add(collection_name, vec_db_items) + logger.info(f"Loaded {len(items)} memories from {collection_name} in {memory_file}") + + except FileNotFoundError: + logger.error(f"Memory file not found in directory: {dir}") + except json.JSONDecodeError as e: + if e.pos == 0 and "Expecting value" in str(e): + logger.warning(f"Memory file is empty or contains only whitespace: {memory_file}") + else: + logger.error(f"Error decoding JSON from memory file: {e}") + except Exception as e: + logger.error(f"An error occurred while loading memories: {e}") + + def dump(self, dir: str) -> None: + """Dump memories to the specified directory. + Args: + dir (str): The directory where the memory files will be saved. + """ + # For preference memory, we don't need to dump to files + # as the data is stored in the vector database + try: + json_memories = {} + for collection_name in self.vector_db.config.collection_name: + items = self.vector_db.get_all(collection_name) + json_memories[collection_name] = [memory.to_dict() for memory in items] + + os.makedirs(dir, exist_ok=True) + memory_file = os.path.join(dir, self.config.memory_filename) + with open(memory_file, "w", encoding="utf-8") as f: + json.dump(json_memories, f, indent=4, ensure_ascii=False) + + logger.info( + f"Dumped {len(json_memories)} collections, {sum(len(items) for items in json_memories.values())} memories to {memory_file}" + ) + + except Exception as e: + logger.error(f"An error occurred while dumping memories: {e}") + raise + + def extract(self, messages: MessageList) -> list[TextualMemoryItem]: + """Extract memories based on the messages. + Args: + messages (MessageList): The messages to extract memories from. + Returns: + list[TextualMemoryItem]: List of extracted memory items. + """ + raise NotImplementedError + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. 
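+
+        Returns:
+            list[str]: IDs of the memories that were added or updated.
+
+        Example (illustrative; memories as produced by get_memory above):
+            >>> ids = pref_mem.add(pref_mem.get_memory([messages], "chat", info))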
+        """
+        return self.adder.add(memories)
+
+    def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
+        """Update a memory by memory_id."""
+        raise NotImplementedError
+
+    def get(self, memory_id: str) -> TextualMemoryItem:
+        """Get a memory by its ID.
+        Args:
+            memory_id (str): The ID of the memory to retrieve.
+        Returns:
+            TextualMemoryItem: The memory with the given ID.
+        """
+        raise NotImplementedError
+
+    def get_with_collection_name(
+        self, collection_name: str, memory_id: str
+    ) -> TextualMemoryItem | None:
+        """Get a memory by its ID and collection name.
+        Args:
+            memory_id (str): The ID of the memory to retrieve.
+            collection_name (str): The name of the collection to retrieve the memory from.
+        Returns:
+            TextualMemoryItem: The memory with the given ID and collection name.
+        """
+        try:
+            res = self.vector_db.get_by_id(collection_name, memory_id)
+            if res is None:
+                return None
+            return TextualMemoryItem(
+                id=res.id,
+                memory=res.payload.get("dialog_str", ""),
+                metadata=PreferenceTextualMemoryMetadata(**res.payload),
+            )
+        except Exception as e:
+            # Convert any other exception to ValueError for consistent error handling
+            raise ValueError(
+                f"Memory with ID {memory_id} not found in collection {collection_name}: {e}"
+            ) from e
+
+    def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
+        """Get memories by their IDs.
+        Args:
+            memory_ids (list[str]): List of memory IDs to retrieve.
+        Returns:
+            list[TextualMemoryItem]: List of memories with the specified IDs.
+        """
+        raise NotImplementedError
+
+    def get_by_ids_with_collection_name(
+        self, collection_name: str, memory_ids: list[str]
+    ) -> list[TextualMemoryItem]:
+        """Get memories by their IDs and collection name.
+        Args:
+            collection_name (str): The name of the collection to retrieve the memory from.
+            memory_ids (list[str]): List of memory IDs to retrieve.
+        Returns:
+            list[TextualMemoryItem]: List of memories with the specified IDs and collection name.
+        """
+        try:
+            res = self.vector_db.get_by_ids(collection_name, memory_ids)
+            if not res:
+                return []
+            return [
+                TextualMemoryItem(
+                    id=memo.id,
+                    memory=memo.payload.get("dialog_str", ""),
+                    metadata=PreferenceTextualMemoryMetadata(**memo.payload),
+                )
+                for memo in res
+            ]
+        except Exception as e:
+            # Convert any other exception to ValueError for consistent error handling
+            raise ValueError(
+                f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}"
+            ) from e
+
+    def get_all(self) -> dict[str, list[TextualMemoryItem]]:
+        """Get all memories, grouped by collection.
+        Returns:
+            dict[str, list[TextualMemoryItem]]: Mapping from collection name to its memories.
+        """
+        all_collections = self.vector_db.list_collections()
+        all_memories = {}
+        for collection_name in all_collections:
+            items = self.vector_db.get_all(collection_name)
+            all_memories[collection_name] = [
+                TextualMemoryItem(
+                    id=memo.id,
+                    memory=memo.payload.get("dialog_str", ""),
+                    metadata=PreferenceTextualMemoryMetadata(**memo.payload),
+                )
+                for memo in items
+            ]
+        return all_memories
+
+    def delete(self, memory_ids: list[str]) -> None:
+        """Delete memories.
+        Args:
+            memory_ids (list[str]): List of memory IDs to delete.
+        """
+        raise NotImplementedError
+
+    def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None:
+        """Delete memories by their IDs and collection name.
+        Args:
+            collection_name (str): The name of the collection to delete the memory from.
+            memory_ids (list[str]): List of memory IDs to delete.
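+
+        Note: deletion is scoped per collection ("explicit_preference" or
+        "implicit_preference"); the unscoped delete() intentionally raises
+        NotImplementedError for preference memory.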
+ """ + self.vector_db.delete(collection_name, memory_ids) + + def delete_all(self) -> None: + """Delete all memories.""" + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() + + def drop( + self, + ) -> None: + """Drop all databases.""" + raise NotImplementedError diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py new file mode 100644 index 000000000..29f30d384 --- /dev/null +++ b/src/memos/memories/textual/simple_preference.py @@ -0,0 +1,156 @@ +from typing import Any + +from memos.embedders.factory import ( + ArkEmbedder, + OllamaEmbedder, + SenTranEmbedder, + UniversalAPIEmbedder, +) +from memos.llms.factory import AzureLLM, OllamaLLM, OpenAILLM +from memos.log import get_logger +from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.preference import PreferenceTextMemory +from memos.types import MessageList +from memos.vec_dbs.factory import MilvusVecDB, QdrantVecDB + + +logger = get_logger(__name__) + + +class SimplePreferenceTextMemory(PreferenceTextMemory): + """Preference textual memory implementation for storing and retrieving memories.""" + + def __init__( + self, + extractor_llm: OpenAILLM | OllamaLLM | AzureLLM, + vector_db: MilvusVecDB | QdrantVecDB, + embedder: OllamaEmbedder | ArkEmbedder | SenTranEmbedder | UniversalAPIEmbedder, + reranker, + extractor, + adder, + retriever, + ): + """Initialize memory with the given configuration.""" + self.extractor_llm = extractor_llm + self.vector_db = vector_db + self.embedder = embedder + self.reranker = reranker + self.extractor = extractor + self.adder = adder + self.retriever = retriever + + def get_memory( + self, messages: list[MessageList], type: str, info: dict[str, Any] + ) -> list[TextualMemoryItem]: + """Get memory based on the messages. + Args: + messages (MessageList): The messages to get memory from. + type (str): The type of memory to get. + info (dict[str, Any]): The info to get memory. + """ + return self.extractor.extract(messages, type, info) + + def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + """Search for memories based on a query. + Args: + query (str): The query to search for. + top_k (int): The number of top results to return. + info (dict): Leave a record of memory consumption. + Returns: + list[TextualMemoryItem]: List of matching memories. + """ + return self.retriever.retrieve(query, top_k, info) + + def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + """Add memories. + + Args: + memories: List of TextualMemoryItem objects or dictionaries to add. + """ + return self.adder.add(memories) + + def get_with_collection_name( + self, collection_name: str, memory_id: str + ) -> TextualMemoryItem | None: + """Get a memory by its ID and collection name. + Args: + memory_id (str): The ID of the memory to retrieve. + collection_name (str): The name of the collection to retrieve the memory from. + Returns: + TextualMemoryItem: The memory with the given ID and collection name. 
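+
+        Example (illustrative):
+            >>> item = mem.get_with_collection_name("explicit_preference", mem_id)
+            >>> item is None or isinstance(item.metadata, PreferenceTextualMemoryMetadata)
+            True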
+ """ + try: + res = self.vector_db.get_by_id(collection_name, memory_id) + if res is None: + return None + return TextualMemoryItem( + id=res.id, + memory=res.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**res.payload), + ) + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with ID {memory_id} not found in collection {collection_name}: {e}" + ) from e + + def get_by_ids_with_collection_name( + self, collection_name: str, memory_ids: list[str] + ) -> list[TextualMemoryItem]: + """Get memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to retrieve the memory from. + memory_ids (list[str]): List of memory IDs to retrieve. + Returns: + list[TextualMemoryItem]: List of memories with the specified IDs and collection name. + """ + try: + res = self.vector_db.get_by_ids(collection_name, memory_ids) + if not res: + return [] + return [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in res + ] + except Exception as e: + # Convert any other exception to ValueError for consistent error handling + raise ValueError( + f"Memory with IDs {memory_ids} not found in collection {collection_name}: {e}" + ) from e + + def get_all(self) -> list[TextualMemoryItem]: + """Get all memories. + Returns: + list[TextualMemoryItem]: List of all memories. + """ + all_collections = self.vector_db.list_collections() + all_memories = {} + for collection_name in all_collections: + items = self.vector_db.get_all(collection_name) + all_memories[collection_name] = [ + TextualMemoryItem( + id=memo.id, + memory=memo.payload.get("dialog_str", ""), + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in items + ] + return all_memories + + def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: + """Delete memories by their IDs and collection name. + Args: + collection_name (str): The name of the collection to delete the memory from. + memory_ids (list[str]): List of memory IDs to delete. + """ + self.vector_db.delete(collection_name, memory_ids) + + def delete_all(self) -> None: + """Delete all memories.""" + for collection_name in self.vector_db.config.collection_name: + self.vector_db.delete_collection(collection_name) + self.vector_db.create_collection() diff --git a/src/memos/templates/instruction_completion.py b/src/memos/templates/instruction_completion.py new file mode 100644 index 000000000..7ad0fe190 --- /dev/null +++ b/src/memos/templates/instruction_completion.py @@ -0,0 +1,43 @@ +from typing import Any + +from memos.templates.prefer_complete_prompt import PREF_INSTRUCTIONS + + +def instruct_completion( + memories: list[dict[str, Any]] | None = None, +) -> str: + """Create instruction following the preferences.""" + explicit_pref = [] + implicit_pref = [] + for memory in memories: + pref_type = memory.get("metadata", {}).get("preference_type") + if pref_type == "explicit_preference": + pref = memory.get("metadata", {}).get("explicit_preference", None) + if pref: + explicit_pref.append(pref) + elif pref_type == "implicit_preference": + pref = memory.get("metadata", {}).get("implicit_preference", None) + if pref: + implicit_pref.append(pref) + + explicit_pref_str = ( + "Explicit Preference:\n" + + "\n".join(f"{i + 1}. 
{pref}" for i, pref in enumerate(explicit_pref)) + if explicit_pref + else "" + ) + implicit_pref_str = ( + "Implicit Preference:\n" + + "\n".join(f"{i + 1}. {pref}" for i, pref in enumerate(implicit_pref)) + if implicit_pref + else "" + ) + + if not explicit_pref_str and not implicit_pref_str: + return "" + if not explicit_pref_str: + return implicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("explicit preferences > ", "") + if not implicit_pref_str: + return explicit_pref_str + "\n" + PREF_INSTRUCTIONS.replace("implicit preferences > ", "") + + return explicit_pref_str + "\n" + implicit_pref_str + "\n" + PREF_INSTRUCTIONS diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py new file mode 100644 index 000000000..d40b7b778 --- /dev/null +++ b/src/memos/templates/prefer_complete_prompt.py @@ -0,0 +1,250 @@ +NAIVE_EXPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference extraction assistant. +Please extract the user's explicitly mentioned preferences from the following conversation. + +Notes: +- A preference means the user's explicit attitude or choice toward something. It is not limited to words like "like/dislike/want/don't want/prefer". +- This includes, but is not limited to, any user's explicitly expressed inclination, desire, rejection, or priority that counts as an explicit preference. +- Focus on extracting the user's preferences in query. Do not extract preferences from the assistant's responses unless the user explicitly agrees with or endorses the assistant's suggestions. +- When the user modifies or updates their preferences for the same topic or event, extract the complete evolution process of their preference changes, including both the original and updated preferences. + +Requirements: +1. Keep only the preferences explicitly mentioned by the user. Do not infer or assume. +2. Output should be a list of concise natural language summaries and the corresponding context summary, context summary must contain complete information of the conversation fragment that the preference is mentioned. +3. If multiple preferences are mentioned within the same topic, you need to merge the preferences and context summary. + +Conversation: +{qa_pair} + +Find ALL explicit preferences. If no explicit preferences found, return []. Output JSON only: +```json +[ + { + "explicit_preference": "A short natural language summary of the preferences", + "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation, do not lack any scenario information", + "reasoning": "reasoning process to find the explicit preferences" + }, +] +``` +""" + + +NAIVE_IMPLICIT_PREFERENCE_EXTRACT_PROMPT = """ +You are a preference inference assistant. Please extract **implicit preferences** from the following conversation +(preferences that the user did not explicitly state but can be reasonably inferred from context, behavior, frequency, comparisons, exclusions, or scenario choices). + +Notes: +- Implicit preferences refer to user inclinations or choices that are not directly expressed, but can be reasonably inferred from factual cues in the conversation. +- Do not treat explicitly stated preferences as implicit preferences; this prompt is only for inferring preferences that are not directly mentioned. + +Requirements: +1. Only make inferences when there is sufficient evidence in the conversation; avoid unsupported or far-fetched guesses. +2. 
Output a concise natural language statement; do not use lists, categories, or include the reasoning process.
+3. Inferred implicit preferences must not conflict with explicit preferences.
+4. For implicit_preference: only output the preference statement itself; do not include any extra explanation, reasoning, or confidence information. Put all reasoning and explanation in the reasoning field.
+5. If no implicit preference can be reasonably inferred, leave the implicit_preference field empty (do not output anything else).
+
+Conversation:
+{qa_pair}
+
+Output format:
+```json
+{
+    "implicit_preference": "A concise natural language statement of the implicit preferences reasonably inferred from the conversation, or an empty string",
+    "context_summary": "The corresponding context summary, which is a summary of the corresponding conversation; do not omit any scenario information",
+    "reasoning": "Briefly explain the reasoning process for the implicit preference"
+}
+```
+Don't output anything except the JSON.
+"""
+
+
+NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """
+You are a content comparison expert. You are given old and new information, each containing a question, an answer, a topic name, and a topic description.
+Please judge whether these two pieces of information express the **same question or core content**, regardless of differences in expression, detail, or examples. The judgment criteria are as follows:
+
+- If the core content is consistent, that is, the essence of the question, goal, or core concept to be solved is the same, it counts as "same".
+- Different expressions or examples with a consistent core meaning also count as "same".
+- If the goals, concepts involved, or solution ideas differ, it counts as "different".
+
+Please output in JSON format:
+{
+    "is_same": true/false,
+    "reasoning": "Briefly explain the judgment basis, highlighting whether the core content is consistent"
+}
+
+**Old Information:**
+{old_information}
+
+**New Information:**
+{new_information}
+"""
+
+
+NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE = """
+# User Preference Memory Management Agent
+
+You are a **User Preference Memory Management Agent**.
+Your goal is to maintain a user's long-term **preference memory base** by analyzing new preference information and determining how it should update existing memories.
+
+Each memory entry contains three fields:
+- **id**: a unique identifier for the memory.
+- **context_summary**: a factual summary of the dialogue or situation from which the preference was extracted.
+- **preference**: the extracted statement describing the user's preference or tendency.
+
+When updating a preference, you should also integrate and update the corresponding `context_summary` to ensure both fields stay semantically consistent.
+
+You must produce a complete **operation trace**, showing which memory entries (identified by unique IDs) should be **added**, **updated**, or **deleted**, and then output the **final memory state** after all operations.
+
+## Input Format
+
+New preference memory (new_memory):
+{new_memory}
+
+Retrieved preference memories (retrieved_memories):
+{retrieved_memories}
+
+## Task Instructions
+
+1. 
Analyze each retrieved memory and determine its relationship to the new memory: + - **Unrelated** → perform `"ADD"` (insert as a new independent memory); + - **Related** → perform `"UPDATE"` (refine, supplement, or merge both the `preference` and the `context_summary`); + - **Conflicting or outdated** → perform `"DELETE"` (remove obsolete or contradictory memory). + +2. If multiple retrieved memories describe the same preference theme, merge them into one updated memory entry, combining both their `preference` information and their `context_summary` in a coherent and concise way. + +3. Output a structured list of **operation traces**, each explicitly stating: + - which memory (by ID) is affected, + - what operation is performed, + - the before/after `preference` and `context_summary`, + - and the reasoning behind it. + +4. Output the **final memory state (after_update_state)**, representing the complete preference memory base after applying all operations. + +## Output Format (JSON) + +{ + "trace": [ + { + "op_id": "op_1", + "type": "ADD" | "UPDATE" | "DELETE", + "target_id": "(the old memory ID; null if ADD)", + "old_preference": "(the old preference text; null if ADD)", + "old_context_summary": "(the old context summary; null if ADD)", + "new_preference": "(the updated or newly created preference, if applicable)", + "new_context_summary": "(the updated or newly created context summary, if applicable)", + "reason": "(brief natural-language explanation for the decision)" + } + ], + "after_update_state": [ + { + "id": "id1", + "context_summary": "updated factual summary of the context", + "preference": "updated or final preference text" + } + ] +} + +## Example + +**Input:** +new_memory: +{ + "context_summary": "During a recent chat about study habits, the user mentioned that he often studies in quiet coffee shops and has started preferring lattes over Americanos, which he only drinks occasionally.", + "preference": "User now prefers lattes but occasionally drinks Americanos; he also enjoys studying in quiet coffee shops." +} + +retrieved_memories: +[ + { + "id": "id1", + "context_summary": "The user previously said he likes coffee in general.", + "preference": "User likes coffee." + }, + { + "id": "id2", + "context_summary": "The user once mentioned preferring Americanos during work breaks.", + "preference": "User prefers Americanos." + }, + { + "id": "id3", + "context_summary": "The user said he often works from home", + "preference": "User likes working from home." + }, + { + "id": "id4", + "context_summary": "The user noted he doesn't drink tea very often.", + "preference": "User has no particular interest in tea." + } +] + +**Output:** +{ + "trace": [ + { + "op_id": "op_1", + "type": "UPDATE", + "target_id": "id1", + "old_preference": "User likes coffee.", + "old_context_summary": "The user previously said he likes coffee in general.", + "new_preference": "User likes coffee, especially lattes, but occasionally drinks Americanos.", + "new_context_summary": "The user discussed his coffee habits, stating he now prefers lattes but only occasionally drinks Americanos", + "reason": "The new memory refines and expands the coffee preference and context while preserving frequency semantics ('occasionally')." 
+    },
+    {
+      "op_id": "op_2",
+      "type": "DELETE",
+      "target_id": "id2",
+      "old_preference": "User prefers Americanos.",
+      "old_context_summary": "The user once mentioned preferring Americanos during work breaks.",
+      "new_preference": null,
+      "new_context_summary": null,
+      "reason": "This old memory is now merged into the updated coffee preference (id1)."
+    },
+    {
+      "op_id": "op_3",
+      "type": "UPDATE",
+      "target_id": "id3",
+      "old_preference": "User likes working from home.",
+      "old_context_summary": "The user said he often works from home.",
+      "new_preference": "User now prefers studying in quiet coffee shops instead of working from home.",
+      "new_context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.",
+      "reason": "The preference has changed for the working environment."
+    }
+  ],
+  "after_update_state": [
+    {
+      "id": "id1",
+      "context_summary": "The user discussed his coffee habits, saying he now prefers lattes but only occasionally drinks Americanos.",
+      "preference": "User likes coffee, especially lattes, but occasionally drinks Americanos."
+    },
+    {
+      "id": "id3",
+      "context_summary": "The user mentioned shifting from working at home to studying in quiet cafes, reflecting a new preferred environment.",
+      "preference": "User now prefers studying in quiet coffee shops instead of working from home."
+    },
+    {
+      "id": "id4",
+      "context_summary": "The user noted he doesn't drink tea very often.",
+      "preference": "User has no particular interest in tea."
+    }
+  ]
+}
+
+## Output Requirements
+
+- The output **must** be valid JSON.
+- Each operation must include both `preference` and `context_summary` updates where applicable.
+- Each operation must include a clear `reason`.
+- Multiple retrieved memories may be merged into one unified updated memory.
+- `after_update_state` must reflect the final, post-update state of the preference memory base.
+- Do **not** include any explanatory text outside the JSON.
+"""
+
+
+PREF_INSTRUCTIONS = """
+# Note:
+Plaintext memories are summaries of facts, while preference memories are summaries of user preferences.
+Your response must not violate any of the user's preferences, whether explicit or implicit; briefly explain how your answer respects them to avoid conflicts.
+When encountering preference conflicts, the priority is: explicit preference > implicit preference > plaintext memory.
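+For example, if an explicit preference says the user avoids seafood while a
+plaintext memory records a past seafood dinner, follow the explicit preference
+and briefly note why.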
+""" diff --git a/src/memos/vec_dbs/factory.py b/src/memos/vec_dbs/factory.py index 8df22d14d..f2950b4ea 100644 --- a/src/memos/vec_dbs/factory.py +++ b/src/memos/vec_dbs/factory.py @@ -2,6 +2,7 @@ from memos.configs.vec_db import VectorDBConfigFactory from memos.vec_dbs.base import BaseVecDB +from memos.vec_dbs.milvus import MilvusVecDB from memos.vec_dbs.qdrant import QdrantVecDB @@ -10,6 +11,7 @@ class VecDBFactory(BaseVecDB): backend_to_class: ClassVar[dict[str, Any]] = { "qdrant": QdrantVecDB, + "milvus": MilvusVecDB, } @classmethod diff --git a/src/memos/vec_dbs/item.py b/src/memos/vec_dbs/item.py index 6f74879ac..081400f15 100644 --- a/src/memos/vec_dbs/item.py +++ b/src/memos/vec_dbs/item.py @@ -41,3 +41,9 @@ def from_dict(cls, data: dict[str, Any]) -> "VecDBItem": def to_dict(self) -> dict[str, Any]: """Convert to dictionary format.""" return self.model_dump(exclude_none=True) + + +class MilvusVecDBItem(VecDBItem): + """Represents a single item in the Milvus vector database.""" + + memory: str | None = Field(default=None, description="Memory string") diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 7bb1ceeba..fb19fd6ff 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -4,7 +4,7 @@ from memos.dependency import require_python_package from memos.log import get_logger from memos.vec_dbs.base import BaseVecDB -from memos.vec_dbs.item import VecDBItem +from memos.vec_dbs.item import MilvusVecDBItem logger = get_logger(__name__) @@ -40,6 +40,7 @@ def create_schema(self): schema.add_field( field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True ) + schema.add_field(field_name="memory", datatype=DataType.VARCHAR, max_length=65535) schema.add_field( field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension ) @@ -107,7 +108,7 @@ def search( collection_name: str, top_k: int, filter: dict[str, Any] | None = None, - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Search for similar items in the database. 
@@ -136,8 +137,9 @@ def search( entity = hit.get("entity", {}) items.append( - VecDBItem( + MilvusVecDBItem( id=str(hit["id"]), + memory=entity.get("memory"), vector=entity.get("vector"), payload=entity.get("payload", {}), score=1 - float(hit["distance"]), @@ -178,7 +180,7 @@ def _get_metric_type(self) -> str: } return metric_map.get(self.config.distance_metric, "L2") - def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: + def get_by_id(self, collection_name: str, id: str) -> MilvusVecDBItem | None: """Get a single item by ID.""" results = self.client.get( collection_name=collection_name, @@ -191,13 +193,14 @@ def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None: entity = results[0] payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} - return VecDBItem( + return MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) - def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: + def get_by_ids(self, collection_name: str, ids: list[str]) -> list[MilvusVecDBItem]: """Get multiple items by their IDs.""" results = self.client.get( collection_name=collection_name, @@ -211,8 +214,9 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: for entity in results: payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]} items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -222,7 +226,7 @@ def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]: def get_by_filter( self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100 - ) -> list[VecDBItem]: + ) -> list[MilvusVecDBItem]: """ Retrieve all items that match the given filter criteria using query_iterator. @@ -252,13 +256,14 @@ def get_by_filter( if not batch_results: break - # Convert batch results to VecDBItem objects + # Convert batch results to MilvusVecDBItem objects for entity in batch_results: # Extract the actual payload from Milvus entity payload = entity.get("payload", {}) all_items.append( - VecDBItem( + MilvusVecDBItem( id=entity["id"], + memory=entity.get("memory"), vector=entity.get("vector"), payload=payload, ) @@ -274,7 +279,7 @@ def get_by_filter( logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.") return all_items - def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]: + def get_all(self, collection_name: str, scroll_limit=100) -> list[MilvusVecDBItem]: """Retrieve all items in the vector database.""" return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit) @@ -295,13 +300,14 @@ def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> i # Extract row count from stats - stats is a dict, not a list return int(stats.get("row_count", 0)) - def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None: + def add(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None: """ Add data to the vector database. 
Args:
-            data: List of VecDBItem objects or dictionaries containing:
+            data: List of MilvusVecDBItem objects or dictionaries containing:
                 - 'id': unique identifier
+                - 'memory': memory string
                 - 'vector': embedding vector
                 - 'payload': additional fields for filtering/retrieval
         """
@@ -309,11 +315,12 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N
         for item in data:
             if isinstance(item, dict):
                 item = item.copy()
-                item = VecDBItem.from_dict(item)
+                item = MilvusVecDBItem.from_dict(item)
 
             # Prepare entity data
             entity = {
                 "id": item.id,
+                "memory": item.memory,
                 "vector": item.vector,
                 "payload": item.payload if item.payload else {},
             }
@@ -326,11 +333,15 @@ def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> N
             data=entities,
         )
 
-    def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None:
+    def update(self, collection_name: str, id: str, data: MilvusVecDBItem | dict[str, Any]) -> None:
         """Update an item in the vector database."""
         if isinstance(data, dict):
             data = data.copy()
-            data = VecDBItem.from_dict(data)
+            data = MilvusVecDBItem.from_dict(data)
+        # Validate after normalization: dict inputs have no .id attribute, so
+        # checking before conversion would raise AttributeError.
+        if id != data.id:
+            raise ValueError(f"ID mismatch: expected {id}, got {data.id}")
 
         # Use upsert for updates
         self.upsert(collection_name, [data])
@@ -347,7 +358,7 @@ def ensure_payload_indexes(self, fields: list[str]) -> None:
         # Field indexes are created automatically for scalar fields
         logger.info(f"Milvus automatically indexes scalar fields: {fields}")
 
-    def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None:
+    def upsert(self, collection_name: str, data: list[MilvusVecDBItem | dict[str, Any]]) -> None:
         """
         Add or update data in the vector database.
diff --git a/tests/configs/test_mem_cube.py b/tests/configs/test_mem_cube.py
index 6c962dd01..c50195558 100644
--- a/tests/configs/test_mem_cube.py
+++ b/tests/configs/test_mem_cube.py
@@ -28,7 +28,7 @@ def test_base_mem_cube_config():
 def test_general_mem_cube_config():
     check_config_base_class(
         GeneralMemCubeConfig,
-        factory_fields=["text_mem", "act_mem", "para_mem"],
+        factory_fields=["text_mem", "act_mem", "para_mem", "pref_mem"],
         required_fields=[],
         optional_fields=["config_filename", "user_id", "cube_id"],
         reserved_fields=["model_schema"],
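For reviewers: a sketch of the cube-config shape the new `pref_mem` factory slot suggests. The backend key and every nested backend name below are assumptions inferred from `PreferenceTextMemoryConfig` and the factories in this diff, not the authoritative schema:

```python
# Hypothetical cube-config fragment for the new "pref_mem" slot.
# All backend names are illustrative assumptions only.
pref_mem = {
    "backend": "preference_text",  # assumed registry key for PreferenceTextMemory
    "config": {
        "extractor_llm": {"backend": "openai", "config": {}},
        "vector_db": {"backend": "milvus", "config": {}},
        "embedder": {"backend": "universal_api", "config": {}},
        "reranker": {"backend": "cosine_local", "config": {}},
        "extractor": {"backend": "naive", "config": {}},
        "adder": {"backend": "naive", "config": {}},
        "retriever": {"backend": "naive", "config": {}},
    },
}
```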