diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 558ddc5a8..d05512566 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ + + repos: - repo: https://github.com/pycqa/flake8.git rev: 4.0.0 @@ -7,7 +9,8 @@ repos: (?x)^( thirdparty/| examples/| - tests/run.py + tests/run.py| + ms_agent/utils/prompt.py )$ - repo: https://github.com/PyCQA/isort.git rev: 4.3.21 diff --git a/ms_agent/agent/agent.yaml b/ms_agent/agent/agent.yaml index e29cf3f34..85efe4248 100644 --- a/ms_agent/agent/agent.yaml +++ b/ms_agent/agent/agent.yaml @@ -1,6 +1,6 @@ llm: service: modelscope - model: Qwen/Qwen3-235B-A22B + model: Qwen/Qwen3-235B-A22B-Instruct-2507 modelscope_api_key: modelscope_base_url: https://api-inference.modelscope.cn/v1 @@ -31,20 +31,22 @@ prompt: 6. Do not call tools carelessly. Show your thoughts **as detailed as possible**. - For requests that require performing a specific task or retrieving information, you must use the following format: + 7. Respond in the same language the user uses. If the user switches, switch accordingly. + + For requests that require performing a specific task or retrieving information, you must use the following format in user language: ``` - 用户需要 ... - 针对该需求,我进行了详细拆解和规划,需要按照如下步骤来解决问题: + The user needs to ... + I have analyzed this request in detail and broken it down into the following steps: ... ``` If you have tools which may help you to solve problems, follow this format to answer: ``` - 用户需要 ... - 针对该需求,我进行了详细拆解和规划,需要按照如下步骤来解决问题: + The user needs to ... + I have analyzed this request in detail and broken it down into the following steps: ... - 首先我应当选择...工具,由于该工具..., 该工具的入参需要... + First, I should use the [Tool Name] because [explain relevance]. The required input parameters are: ... ... - 我仔细查看了工具返回值,该工具的返回值符合/不符合我的要求,我接下来需要... + I have carefully reviewed the tool's output. The result does/does not fully meet my expectations. Next, I need to ... 
``` max_chat_round: 9999 diff --git a/ms_agent/agent/base.py b/ms_agent/agent/base.py index 421d30be4..8541a40ed 100644 --- a/ms_agent/agent/base.py +++ b/ms_agent/agent/base.py @@ -9,7 +9,7 @@ from ms_agent.config import Config from ms_agent.config.config import ConfigLifecycleHandler from ms_agent.llm import Message -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf DEFAULT_YAML = os.path.join( os.path.dirname(os.path.abspath(__file__)), 'agent.yaml') @@ -43,10 +43,10 @@ def __init__(self, trust_remote_code: bool = False): if config_dir_or_id is not None: self.config: DictConfig = Config.from_task(config_dir_or_id, env) - elif config is not None and isinstance(config, DictConfig): - self.config: DictConfig = config else: self.config: DictConfig = Config.from_task(DEFAULT_YAML) + if config is not None and isinstance(config, DictConfig): + self.config = OmegaConf.merge(self.config, config) if tag is None: self.tag = getattr(config, 'tag', None) or self.DEFAULT_TAG diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py index 1c325b5bb..7e09eedd6 100644 --- a/ms_agent/agent/llm_agent.py +++ b/ms_agent/agent/llm_agent.py @@ -11,6 +11,8 @@ from ms_agent.callbacks import Callback, callbacks_mapping from ms_agent.llm.llm import LLM from ms_agent.llm.utils import Message, Tool +from ms_agent.memory import Memory, memory_mapping +from ms_agent.memory.mem0ai import Mem0Memory from ms_agent.rag.base import RAG from ms_agent.rag.utils import rag_mapping from ms_agent.tools import ToolManager @@ -20,7 +22,6 @@ from ..utils.utils import read_history, save_history from .base import Agent -from .memory import Memory, memory_mapping from .plan.base import Planer from .plan.utils import planer_mapping from .runtime import Runtime @@ -229,20 +230,52 @@ async def _prepare_messages( return messages async def _prepare_memory(self): - """Load and initialize memory components from the config.""" + """ + Prepare memory + + Initializes and 
appends memory tool instances based on the configuration provided in self.config. + Args: + self: The instance of the class containing this method. Expected to have: + - config: An object that may contain a memory attribute, which is a list of memory configurations. + + Returns: + None + + Raises: + AssertionError: If a specified memory type in the config does not exist in memory_mapping. + """ if hasattr(self.config, 'memory'): for _memory in (self.config.memory or []): - assert _memory.name in memory_mapping, ( - f'{_memory.name} not in memory_mapping, ' + memory_type = getattr(_memory, 'name', 'default_memory') + assert memory_type in memory_mapping, ( + f'{memory_type} not in memory_mapping, ' f'which supports: {list(memory_mapping.keys())}') - if _memory.name == 'mem0': - from .memory.mem0ai import SharedMemoryManager + + # Use LLM config if no special configuration is specified + llm_config = getattr(_memory, 'llm', None) + if llm_config is None: + service = self.config.llm.service + config_dict = { + 'model': + self.config.llm.model, + 'provider': + 'openai', + 'openai_base_url': + getattr(self.config.llm, f'{service}_base_url', None), + 'openai_api_key': + getattr(self.config.llm, f'{service}_api_key', None), + } + llm_config_obj = OmegaConf.create(config_dict) + setattr(_memory, 'llm', llm_config_obj) + + if memory_type == 'mem0': + from ms_agent.memory.mem0ai import SharedMemoryManager shared_memory = SharedMemoryManager.get_shared_memory( _memory) self.memory_tools.append(shared_memory) else: self.memory_tools.append( - memory_mapping[_memory.name](_memory)) + memory_mapping[memory_type](_memory)) async def _prepare_planer(self): """Load and initialize the planer component from the config.""" @@ -464,7 +497,9 @@ def _save_history(self, messages: List[Message], **kwargs): user_id = memory_config.user_id break for memory_tool in self.memory_tools: - memory_tool._add_memories_from_conversation(messages, user_id) + if isinstance(memory_tool, Mem0Memory): + 
memory_tool._add_memories_from_conversation( + messages, user_id) if not self.task or self.task == 'subtask': return @@ -487,8 +522,9 @@ def _save_memory(self, messages: List[Message], **kwargs): if self.memory_tools: agent_id = self.tag for memory_tool in self.memory_tools: - memory_tool._add_memories_from_procedural( - messages, 'subagent', agent_id, 'procedural_memory') + if isinstance(memory_tool, Mem0Memory): + memory_tool._add_memories_from_procedural( + messages, 'subagent', agent_id, 'procedural_memory') return async def _run(self, messages: Union[List[Message], str], diff --git a/ms_agent/agent/memory/__init__.py b/ms_agent/memory/__init__.py similarity index 60% rename from ms_agent/agent/memory/__init__.py rename to ms_agent/memory/__init__.py index 3352d6d42..cbbb992be 100644 --- a/ms_agent/agent/memory/__init__.py +++ b/ms_agent/memory/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .base import Memory -from .utils import memory_mapping +from .utils import DefaultMemory, memory_mapping diff --git a/ms_agent/agent/memory/base.py b/ms_agent/memory/base.py similarity index 88% rename from ms_agent/agent/memory/base.py rename to ms_agent/memory/base.py index b8ea7496c..409f1d483 100644 --- a/ms_agent/agent/memory/base.py +++ b/ms_agent/memory/base.py @@ -8,6 +8,9 @@ class Memory: """The memory refine tool""" + def __init__(self, config): + self.config = config + @abstractmethod async def run(self, messages: List[Message]) -> List[Message]: """Refine the messages diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py new file mode 100644 index 000000000..71611ae17 --- /dev/null +++ b/ms_agent/memory/default_memory.py @@ -0,0 +1,442 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import hashlib +import os +from copy import deepcopy +from functools import partial, wraps +from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union + +import json +import json5 +from ms_agent.llm.utils import Message +from ms_agent.memory import Memory +from ms_agent.utils import get_fact_retrieval_prompt +from ms_agent.utils.logger import logger +from omegaconf import DictConfig, OmegaConf + + +class MemoryMapping: + memory_id: str = None + memory: str = None + valid: bool = None + enable_idxs: List[int] = None  # set per instance in __init__ + disable_idx: int = -1 + + def __init__(self, memory_id: str, value: str, enable_idxs: Union[int, + List[int]]): + self.memory_id = memory_id + self.value = value + self.valid = True + if isinstance(enable_idxs, int): + enable_idxs = [enable_idxs] + self.enable_idxs = list(enable_idxs)  # copy: extended in place + + def udpate_idxs(self, enable_idxs: Union[int, List[int]]): + if isinstance(enable_idxs, int): + enable_idxs = [enable_idxs] + self.enable_idxs.extend(enable_idxs) + + def disable(self, disable_idx: int): + self.valid = False + self.disable_idx = disable_idx + + def try_enable(self, expired_disable_idx: int): + if expired_disable_idx == self.disable_idx: + self.valid = True + self.disable_idx = -1 + + def get(self): + return self.value + + def to_dict(self) -> Dict: + return { + 'memory_id': self.memory_id, + 'value': self.value, + 'valid': self.valid, + 'enable_idxs': self.enable_idxs.copy( + ), # Return a copy to prevent external modification + 'disable_idx': self.disable_idx + } + + @classmethod + def from_dict(cls, data: Dict) -> 'MemoryMapping': + instance = cls( + memory_id=data['memory_id'], + value=data['value'], + enable_idxs=data['enable_idxs']) + instance.valid = data['valid'] + instance.disable_idx = data.get('disable_idx', + -1) # Compatible with old data + return instance + + +class DefaultMemory(Memory): + """The memory refine tool""" + + def __init__(self, config: DictConfig): + super().__init__(config) + self.user_id: Optional[str] = getattr(self.config, 
'user_id', None) + self.persist: Optional[bool] = getattr(config, 'persist', True) + self.compress: Optional[bool] = getattr(config, 'compress', True) + self.is_retrieve: Optional[bool] = getattr(config, 'is_retrieve', True) + self.path: Optional[str] = getattr(self.config, 'path', 'output') + self.history_mode = getattr(config, 'history_mode', 'add') + self.ignore_role: List[str] = getattr(config, 'ignore_role', + ['tool', 'system']) + self.ignore_fields: List[str] = getattr(config, 'ignore_fields', + ['reasoning_content']) + self.memory = self._init_memory_obj() + self.init_cache_messages() + + def init_cache_messages(self): + self.load_cache() + if len(self.cache_messages) and not len(self.memory_snapshot): + new_blocks = self._split_into_blocks(self.cache_messages) + for messages in new_blocks: + self.max_msg_id += 1 + self.add(messages, msg_id=self.max_msg_id) + + def save_cache(self): + """ + Save self.max_msg_id, self.cache_messages, and self.memory_snapshot to self.path/cache_messages.json + """ + cache_file = os.path.join(self.path, 'cache_messages.json') + + # Ensure the directory exists + os.makedirs(self.path, exist_ok=True) + + data = { + 'max_msg_id': self.max_msg_id, + 'cache_messages': { + str(k): ([msg.to_dict() for msg in msg_list], _hash) + for k, (msg_list, _hash) in self.cache_messages.items() + }, + 'memory_snapshot': [mm.to_dict() for mm in self.memory_snapshot] + } + + with open(cache_file, 'w', encoding='utf-8') as f: + json5.dump(data, f, indent=2, ensure_ascii=False) + + def load_cache(self): + """ + Load data from self.path/cache_messages.json into self.max_msg_id, self.cache_messages, and self.memory_snapshot + """ + cache_file = os.path.join(self.path, 'cache_messages.json') + + if not os.path.exists(cache_file): + # If the file does not exist, initialize default values and return. 
+ self.max_msg_id = -1 + self.cache_messages = {} + self.memory_snapshot = [] + return + + try: + with open(cache_file, 'r', encoding='utf-8') as f: + data = json5.load(f) + + self.max_msg_id = data.get('max_msg_id', -1) + + # Parse cache_messages + cache_messages = {} + raw_cache_msgs = data.get('cache_messages', {}) + for k, (msg_list, timestamp) in raw_cache_msgs.items(): + msg_objs = [Message(**msg_dict) for msg_dict in msg_list] + cache_messages[int(k)] = (msg_objs, timestamp) + self.cache_messages = cache_messages + + # Parse memory_snapshot + self.memory_snapshot = [ + MemoryMapping.from_dict(d) + for d in data.get('memory_snapshot', []) + ] + + except (json.JSONDecodeError, KeyError, Exception) as e: + logger.warning(f'Failed to load cache: {e}') + # Fall back to default state when an error occurs + self.max_msg_id = -1 + self.cache_messages = {} + self.memory_snapshot = [] + + def delete_single(self, msg_id: int): + messages_to_delete = self.cache_messages.get(msg_id, None) + if messages_to_delete is None: + return + self.cache_messages.pop(msg_id, None) + if msg_id == self.max_msg_id: + self.max_msg_id = max(self.cache_messages, default=-1)  # -1=empty + + idx = 0 + while idx < len(self.memory_snapshot): + + enable_ids = self.memory_snapshot[idx].enable_idxs + disable_id = self.memory_snapshot[idx].disable_idx + if msg_id == disable_id: + self.memory_snapshot[idx].try_enable(msg_id) + self.memory._create_memory( + data=self.memory_snapshot[idx].value, + existing_embeddings={}, + metadata={'user_id': self.user_id}) + if msg_id in enable_ids: + if len(enable_ids) > 1: + self.memory_snapshot[idx].enable_idxs.remove(msg_id) + else: + self.memory.delete(self.memory_snapshot[idx].memory_id) + self.memory_snapshot.pop(idx) + idx -= 1 # After pop, the next item becomes the current idx + + idx += 1 + + def add(self, messages: List[Message], msg_id: int) -> None: + self.cache_messages[msg_id] = messages, self._hash_block(messages) + + messages_dict = [] + for message in messages: + 
if isinstance(message, Message): + messages_dict.append(message.to_dict()) + else: + messages_dict.append(message) + self.memory.add(messages_dict, user_id=self.user_id) + + self.max_msg_id = max(self.max_msg_id, msg_id) + res = self.memory.get_all(user_id=self.user_id) # sorted + res = [(item['id'], item['memory']) for item in res['results']] + if len(res): + logger.info('Add memory success. All memory info:') + for item in res: + logger.info(item[1]) + valids = [] + unmatched = [] + for id, memory in res: + matched = False + for item in self.memory_snapshot: + if id == item.memory_id: + if item.value == memory and item.valid: + matched = True + valids.append(id) + break + else: + if item.valid: + item.disable(msg_id) + if not matched: + unmatched.append((id, memory)) + for item in self.memory_snapshot: + if item.memory_id not in valids: + item.disable(msg_id) + for (id, memory) in unmatched: + m = MemoryMapping(memory_id=id, value=memory, enable_idxs=msg_id) + self.memory_snapshot.append(m) + + def search(self, query: str) -> str: + relevant_memories = self.memory.search( + query, user_id=self.user_id, limit=3) + memories_str = '\n'.join(f"- {entry['memory']}" + for entry in relevant_memories['results']) + return memories_str + + def _split_into_blocks(self, + messages: List[Message]) -> List[List[Message]]: + """ + Split messages into blocks where each block starts with a 'user' message + and includes all following non-user messages until the next 'user' (exclusive). + + The very first messages before the first 'user' (e.g., system) are attached to the first user block. + If no user message exists, all messages go into one block. 
+ """ + if not messages: + return [] + + blocks: List[List[Message]] = [] + current_block: List[Message] = [] + + # Handle leading non-user messages (like system) + have_user = False + for msg in messages: + if msg.role != 'user': + current_block.append(msg) + else: + if have_user: + blocks.append(current_block) + current_block = [msg] + else: + current_block.append(msg) + have_user = True + + # Append the last block + if current_block: + blocks.append(current_block) + + return blocks + + def _hash_block(self, block: List[Message]) -> str: + """Compute sha256 hash of a message block for comparison""" + data = [message.to_dict() for message in block] + allow_role = ['user', 'system', 'assistant', 'tool'] + allow_role = [ + role for role in allow_role if role not in self.ignore_role + ] + allow_fields = ['reasoning_content', 'content', 'tool_calls', 'role'] + allow_fields = [ + field for field in allow_fields if field not in self.ignore_fields + ] + + data = [{ + field: value + for field, value in msg.items() if field in allow_fields + } for msg in data if msg['role'] in allow_role] + + block_data = json5.dumps(data) + return hashlib.sha256(block_data.encode('utf-8')).hexdigest() + + def _analyze_messages( + self, + messages: List[Message]) -> Tuple[List[List[Message]], List[int]]: + """ + Analyze incoming messages against cache. 
+ + Returns: + should_add_messages: blocks to add (not in cache or hash changed) + should_delete: list of msg_id to delete (in cache but not in new blocks) + """ + new_blocks = self._split_into_blocks(messages) + self.cache_messages = dict(sorted(self.cache_messages.items())) + + cache_messages = [(key, value) + for key, value in self.cache_messages.items()] + first_unmatched_idx = -1 + for idx in range(len(new_blocks)): + block_hash = self._hash_block(new_blocks[idx]) + if idx < len(cache_messages) - 1 and str(block_hash) == str( + cache_messages[idx][1][1]): + continue + first_unmatched_idx = idx + break + should_delete = [ + item[0] for item in cache_messages[first_unmatched_idx:] + ] if first_unmatched_idx != -1 else [] + should_add_messages = new_blocks[first_unmatched_idx:] + + return should_add_messages, should_delete + + def _get_user_message(self, block: List[Message]) -> Optional[Message]: + """Helper: get the user message from a block, if exists""" + for msg in block: + if msg.role == 'user': + return msg + return None + + def _should_update_memory(self, messages: List[Message]) -> bool: + # TODO: Avoid unnecessary frequent updates and reduce the number of update operations + return True + + async def run(self, messages, ignore_role=None, ignore_fields=None): + if not self.is_retrieve or not self._should_update_memory(messages): + return messages + should_add_messages, should_delete = self._analyze_messages(messages) + + if should_delete: + if self.history_mode == 'overwrite': + for msg_id in should_delete: + self.delete_single(msg_id=msg_id) + res = self.memory.get_all(user_id=self.user_id) # sorted + res = [(item['id'], item['memory']) for item in res['results']] + logger.info('Roll back success. 
All memory info:') + for item in res: + logger.info(item[1]) + if should_add_messages: + for block in should_add_messages:  # must not rebind `messages` + self.max_msg_id += 1 + self.add(block, msg_id=self.max_msg_id) + self.save_cache() + + query = getattr(messages[-1], 'content') + memories_str = self.search(query) + # Remove the messages section corresponding to memory, and add the related memory_str information + remain_idx = len(messages) - sum( + [len(block) for block in should_add_messages]) + if getattr(messages[0], 'role') == 'system': + system_prompt = getattr( + messages[0], 'content') + f'\nUser Memories: {memories_str}' + if remain_idx < 1: + remain_idx = 1 + else: + system_prompt = f'\nYou are a helpful assistant. Answer the question based on query and memories.\n' \ + f'User Memories: {memories_str}' + + new_messages = [Message(role='system', content=system_prompt) + ] + messages[remain_idx:] + return new_messages + + def _init_memory_obj(self): + import mem0 + parse_messages_origin = mem0.memory.main.parse_messages + + @wraps(parse_messages_origin) + def patched_parse_messages(messages, ignore_role): + response = '' + for msg in messages: + if 'system' not in ignore_role and msg['role'] == 'system': + response += f"system: {msg['content']}\n" + if msg['role'] == 'user': + response += f"user: {msg['content']}\n" + if msg['role'] == 'assistant' and msg['content'] is not None: + response += f"assistant: {msg['content']}\n" + if 'tool' not in ignore_role and msg['role'] == 'tool': + response += f"tool: {msg['content']}\n" + return response + + patched_func = partial( + patched_parse_messages, + ignore_role=self.ignore_role, + ) + + mem0.memory.main.parse_messages = patched_func + + if not self.is_retrieve: + return + + embedder: Optional[DictConfig] = getattr( + self.config, 'embedder', + OmegaConf.create({ + 'provider': 'openai', + 'config': { + 'api_key': os.getenv('DASHSCOPE_API_KEY'), + 'openai_base_url': + 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'model': 
'text-embedding-v4', + } + })) + + llm = {} + if self.compress: + llm_config = getattr(self.config, 'llm', None) + # follow mem0 config + model = llm_config.get('model') + provider = llm_config.get('provider', 'openai') + openai_base_url = llm_config.get('openai_base_url', None) + openai_api_key = llm_config.get('openai_api_key', None) + llm = { + 'provider': provider, + 'config': { + 'model': model, + 'openai_base_url': openai_base_url, + 'api_key': openai_api_key + } + } + + mem0_config = { + 'is_infer': self.compress, + 'llm': llm, + 'vector_store': { + 'provider': 'qdrant', + 'config': { + 'path': self.path, + 'on_disk': self.persist + } + }, + 'embedder': embedder + } + # Log the vector store only: llm/embedder hold API keys, prompt is long + logger.info(f"Memory config: {mem0_config['vector_store']}") + mem0_config['custom_fact_extraction_prompt'] = getattr( + self.config, 'fact_retrieval_prompt', get_fact_retrieval_prompt()) + memory = mem0.Memory.from_config(mem0_config) + return memory diff --git a/ms_agent/agent/memory/mem0ai.py b/ms_agent/memory/mem0ai.py similarity index 98% rename from ms_agent/agent/memory/mem0ai.py rename to ms_agent/memory/mem0ai.py index 03e7fd628..77fb30d9b 100644 --- a/ms_agent/agent/memory/mem0ai.py +++ b/ms_agent/memory/mem0ai.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional from ms_agent.llm.utils import Message -from ms_agent.utils import get_fact_retrieval_prompt, get_logger +from ms_agent.utils import get_code_fact_retrieval_prompt, get_logger from omegaconf import DictConfig from .base import Memory @@ -104,7 +104,7 @@ def _initialize_memory(self): # Monkey patch Mem0's parse_messages function to handle tool messages mem0.memory.main.parse_messages = self.patched_parse_messages # Also update the imported reference in utils module - mem0.memory.utils.FACT_RETRIEVAL_PROMPT = get_fact_retrieval_prompt( + mem0.memory.utils.FACT_RETRIEVAL_PROMPT = get_code_fact_retrieval_prompt( ) embedding_model = 'text-embedding-3-small' diff --git 
a/ms_agent/agent/memory/utils.py b/ms_agent/memory/utils.py similarity index 61% rename from ms_agent/agent/memory/utils.py rename to ms_agent/memory/utils.py index 1f7b3cb89..ea00f78ad 100644 --- a/ms_agent/agent/memory/utils.py +++ b/ms_agent/memory/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +from .default_memory import DefaultMemory from .mem0ai import Mem0Memory memory_mapping = { + 'default_memory': DefaultMemory, 'mem0': Mem0Memory, } diff --git a/ms_agent/utils/__init__.py b/ms_agent/utils/__init__.py index 54122de62..f93f20ce4 100644 --- a/ms_agent/utils/__init__.py +++ b/ms_agent/utils/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .llm_utils import async_retry, retry from .logger import get_logger -from .prompt import get_fact_retrieval_prompt +from .prompt import get_code_fact_retrieval_prompt, get_fact_retrieval_prompt from .utils import assert_package_exist, enhance_error, strtobool MAX_CONTINUE_RUNS = 3 diff --git a/ms_agent/utils/prompt.py b/ms_agent/utils/prompt.py index 01e027f5d..b573539c6 100644 --- a/ms_agent/utils/prompt.py +++ b/ms_agent/utils/prompt.py @@ -5,6 +5,83 @@ def get_fact_retrieval_prompt(): + return f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, preferences, and processing tool interaction outcomes. Your primary role is to extract relevant pieces of information from conversations, organize them into distinct, manageable facts, and additionally process and summarize tool invocation results when present. This ensures both personal data and system interactions are captured for improved context retention and future personalization. + +Types of Information to Remember: +1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment. +2. 
Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates. +3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared. +4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services. +5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information. +6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information. +7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares. + +Tool Interaction Processing Instructions (Additional Responsibilities): +When tool calls and their results are included in the conversation, perform the following in addition to fact extraction: + +1. Extract and Organize Factual Information from Tool Outputs: + - Parse the returned data from successful tool calls (e.g., weather, calendar, search, maps). + - Identify and store objective, user-relevant facts derived from these results (e.g., "It will rain in Paris on 2025-08-25", "The restaurant Little Italy is located at 123 Main St"). + - Integrate these into the "facts" list only if they reflect new, meaningful information about the user's context or environment. +2. Analyze and Summarize Error-Prone Tools: + - Identify tools that frequently fail, time out, or return inconsistent results. + - For such tools, generate a brief internal summary noting the pattern of failure (e.g., "Search tool often returns incomplete results for restaurant queries"). + - This summary does not go into the JSON output but informs future handling (e.g., suggesting alternative tools or double-checking outputs). +3. 
Identify and Log Tools That Cannot Be Called: + - If a tool was intended but not invoked (e.g., due to missing permissions, unavailability, or misconfiguration), note this in a separate internal log. + - Examples: "Calendar tool unavailable — cannot retrieve user's meeting schedule", "Location access denied — weather tool cannot auto-detect city". + - Include a user-facing reminder if relevant: add a fact like "Could not access calendar due to permission restrictions" only if it impacts user understanding. +4. Ensure Clarity and Non-Disclosure: + - Do not expose tool names, system architecture, or internal logs in the output. + - If asked why information is missing, respond: "I tried to retrieve it from publicly available sources, but the information may not be accessible right now." + +Here are some few-shot examples: +Input: Hi. +Output: {{"facts" : []}} + +Input: There are branches in trees. +Output: {{"facts" : []}} + +Input: Hi, I am looking for a restaurant in San Francisco. +Output: {{"facts" : ["Looking for a restaurant in San Francisco"]}} + +Input: Yesterday, I had a meeting with John at 3pm. We discussed the new project. +Output: {{"facts" : ["Had a meeting with John at 3pm", "Discussed the new project"]}} + +Input: Hi, my name is John. I am a software engineer. +Output: {{"facts" : ["Name is John", "Is a Software engineer"]}} + +Input: My favourite movies are Inception and Interstellar. +Output: {{"facts" : ["Favourite movies are Inception and Interstellar"]}} + +Input (with tool call): What's the weather like in Tokyo today? +[Tool Call: get_weather(location="Tokyo", date="2025-08-22") → Result: {{"status": "success", "data": {{"temp": 32°C, "condition": "Sunny", "humidity": 65%}}}}] +Output: {{"facts": ["It is 32°C and sunny in Tokyo today", "Humidity level in Tokyo is 65%"]}} + +Input (with failed tool): Check my calendar for tomorrow's meetings. 
+[Tool Call: get_calendar(date="2025-08-23") → Failed: "Access denied – calendar not connected"] +Output: {{"facts": ["Could not access calendar due to connection issues"]}} + +Input (with unreliable tool pattern): Search for vegan restaurants near Central Park. +[Tool Call: search(query="vegan restaurants near Central Park") → Returns incomplete/no results multiple times] +Output: {{"facts": ["Searching for vegan restaurants near Central Park yielded limited results"]}} +(Internal note: Search tool shows low reliability for location-based queries — consider fallback sources.) + +Final Output Rules: + - Today's date is {datetime.now().strftime("%Y-%m-%d")}. + - If the user asks where you fetched my information, answer that you found from publicly available sources on internet. + - Return only a JSON object with key "facts" and value as a list of strings. + - Do not include anything from the example prompts or system instructions. + - Do not reveal tool usage, internal logs, or model behavior. + - If no relevant personal or environmental facts are found, return: {{"facts": []}} + - Extract facts only from user and assistant messages — ignore system-level instructions. + - Detect the input language and record facts in the same language. + +Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation, process any tool call results, and return them in the JSON format as shown above. +""" + + +def get_code_fact_retrieval_prompt(): return f"""You are a Code Development Information Organizer, specialized in accurately storing development facts, project details, and technical preferences from coding conversations. Your primary role is to extract relevant pieces of technical information that will be useful for future code generation and development tasks. Below are the types of information you need to focus on and the detailed instructions on how to handle the input data. 
Types of Information to Remember:
diff --git a/tests/memory/test_default_memory.py b/tests/memory/test_default_memory.py
new file mode 100644
index 000000000..410f37bc5
--- /dev/null
+++ b/tests/memory/test_default_memory.py
@@ -0,0 +1,237 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from ms_agent.agent import LLMAgent
+from ms_agent.llm.utils import Message, ToolCall
+from omegaconf import OmegaConf
+
+from modelscope.utils.test_utils import test_level
+
+
+class TestDefaultMemory(unittest.TestCase):
+    """Integration tests for memory shared by successive LLMAgent runs."""
+
+    def setUp(self) -> None:
+        """Build the canned tool-calling history shared by the tool tests."""
+        self.tool_history = [
+            Message(
+                role='user',
+                content=
+                'Help me find the coolest sports park in Chaoyang District, Beijing. Remember this location for next'
+                ' time.'),
+            Message(
+                role='assistant',
+                content=
+                '\nThe user wants to find the coolest sports park in Chaoyang District, Beijing. The keyword "cool" '
+                'suggests they are not just looking for functionality, but also value high-tech design, interactive '
+                'experiences, and trendy elements. Therefore, I need to search for venues that have a futuristic feel, '
+                'integrate advanced technology, youth culture, or innovative facilities.\n\nTo solve this, I will take '
+                'the following steps:\n'
+                '1. Use the awesome_map-search tool to search for sports parks in Chaoyang District, Beijing, focusing '
+                'on keywords like "cool", "high-tech", "trendy", etc.\n'
+                '2. Filter out the most distinctive, highly-rated parks with strong visual impact\n'
+                '3. Provide the name of the sports park.\n\nNow I will call the awesome_map-search tool. This tool is '
+                'designed specifically for geo + fun information retrieval, '
+                'supports semantic queries, and excels at discovering "hidden gem" locations.',
+                tool_calls=[
+                    ToolCall(
+                        id='call_xyz789CoolPark',
+                        type='function',
+                        tool_name='awesome_map-search',
+                        arguments=
+                        '{"query": "Beijing Chaoyang District coolest sports park high-tech trendy must-visit spot", '
+                        '"max_results": 1, "result_type":'
+                        '"detailed", "include_features": true}')
+                ]),
+            Message(
+                role='tool',
+                content=
+                '[{"park_name": "Quantum Leap · Chaoyang Future Sports Dome", "address": "No. 8888 Jiuxianqiao Road, '
+                'Chaoyang District, Beijing", "features": ["Anti-gravity floating track using magnetic levitation '
+                'cushioning tech — running feels like jumping on the moon", "AR extreme climbing wall supporting '
+                'virtual scenes such as volcanic apocalypse and interstellar maze", "DJ-powered cycling pods: the '
+                'faster you pedal, the more intense the music, with real-time energy waveform visualization", '
+                '"AI coach robot \'Tie Ran\', which creates custom workout plans and motivates users in Northeastern '
+                'Chinese dialect", "Glow-in-the-dark Tai Chi lake with floating yoga platforms; water contains '
+                'eco-friendly bioluminescent algae glowing blue at night"], "special_events": ["Cool Sports Event '
+                'every Friday night: laser dodgeball, glow-in-the-dark dance run, robot tug-of-war"], "tips": ["Must '
+                'solve 3 fitness puzzles to unlock entry", "Bring a power bank — taking photos too frequently drains '
+                'battery quickly"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "Free"}]'
+            ),
+            Message(
+                role='assistant',
+                content=
+                '\nThe user wanted to find the coolest sports park in Chaoyang District, Beijing. I have retrieved '
+                'relevant information via the awesome_map-search tool.\n\n## 🌟 The Coolest Sports Park in Chaoyang '
+                'District, Beijing Revealed: 「Quantum Leap · Chaoyang Future Sports Dome」'
+            ),
+            Message(role='user', content='Alright, bye!')
+        ]
+
+    def tearDown(self):
+        """Remove the 'output' tree that the memory configs point into."""
+        import shutil
+        shutil.rmtree('output', ignore_errors=True)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_default_memory(self):
+        """Facts stated to one agent should shape a later agent's answer."""
+        import uuid
+        import asyncio
+
+        async def main():
+            random_id = str(uuid.uuid4())
+            default_memory = OmegaConf.create({
+                'memory': [{
+                    'path': f'output/{random_id}',
+                    'user_id': random_id
+                }]
+            })
+            agent1 = LLMAgent(config=default_memory)
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run(
+                'I am a vegetarian and I drink coffee every morning.')
+            del agent1
+            print(
+                '========== Data preparation completed, starting test ==========='
+            )
+            agent2 = LLMAgent(config=default_memory)
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run(
+                'Please help me plan tomorrow’s three meals.')
+            print(res)
+            assert ('vegetarian' in res[-1].content.lower()
+                    or 'vegan' in res[-1].content.lower()
+                    ) and 'coffee' in res[-1].content.lower()
+
+        asyncio.run(main())
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_agent_tool(self):
+        """An address returned by a tool call should be recallable later."""
+        import uuid
+        import asyncio
+
+        async def main():
+            random_id = str(uuid.uuid4())
+            config = OmegaConf.create({
+                'memory': [{
+                    'ignore_role': ['system'],
+                    'user_id': random_id,
+                    'path': f'output/{random_id}'
+                }]
+            })
+            agent1 = LLMAgent(config=OmegaConf.create(config))
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run(self.tool_history)
+            agent1.memory_tools[0].memory.vector_store.client.close()
+            del agent1
+            print(
+                '========== Data preparation completed, starting test ==========='
+            )
+            agent2 = LLMAgent(config=OmegaConf.create(config))
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run(
+                'What is the location of the coolest sports park in Chaoyang District, Beijing?'
+            )
+            print(res)
+            assert 'Jiuxianqiao Road 8888' in res[
+                -1].content or 'No. 8888 Jiuxianqiao Road' in res[-1].content
+
+        asyncio.run(main())
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_overwrite_with_tool(self):
+        """In 'overwrite' mode only the first park should be recalled."""
+        import uuid
+        import asyncio
+
+        async def main():
+            tool_history1 = self.tool_history[:-1] + [
+                Message(
+                    role='user',
+                    content=
+                    'The sports park you mentioned has already closed down.'),
+                Message(
+                    role='assistant',
+                    content=
+                    'The user mentioned that "Quantum Leap · Chaoyang Future Sports Dome" has shut down. Today is '
+                    'May 7, 2045. I need to search again for the currently operating coolest sports park. I will use '
+                    'the awesome_map-search tool with updated time-sensitive keywords such as "open in 2045" to ensure '
+                    'accuracy and timeliness.',
+                    tool_calls=[
+                        ToolCall(
+                            id='call_xyz2045NewPark',
+                            type='function',
+                            tool_name='awesome_map-search',
+                            arguments=
+                            '{"query": "Beijing Chaoyang District coolest sports park high-tech trendy must-visit spot '
+                            'open in 2045", "max_results": 1, "result_type": "detailed", "include_features": true}'
+                        )
+                    ]),
+                Message(
+                    role='tool',
+                    content=
+                    '[{"park_name": "Stellar Core Dynamics · Chaoyang Metaverse Sports Matrix", '
+                    '"address": "No. 99 Aoti South Road, Chaoyang District, Beijing", '
+                    '"features": ["Holographic projection tracks that trigger stardust trails with each step", '
+                    '"Mind-controlled climbing wall — stronger focus increases adhesion", '
+                    '"Gravity-adjustable training pods simulating Mars, Moon, or deep-sea environments", '
+                    '"AI virtual coach \'Neo\' with customizable cross-dimensional avatars", '
+                    '"Nighttime hoverboard pool using magnetic ground propulsion for wheel-free gliding"], '
+                    '"special_events": ["Daily twilight \'Consciousness Awakening Run\': synchronized rhythm via '
+                    'brain-computer interface, generating collective lightstorm"], '
+                    '"tips": ["Neural compatibility test required in advance", '
+                    '"Avoid extreme emotional fluctuations, otherwise system activates calming white noise mode", '
+                    '"Wearing conductive sportswear recommended for better interaction"], "rating": 9.8, '
+                    '"opening_hours": "5:30 - 24:00", "admission": "Free (entry via brainprint registration)"}]'
+                ),
+                Message(
+                    role='assistant',
+                    content=
+                    'The latest and coolest sports park in 2045 is: Stellar Core Dynamics · Chaoyang Metaverse Sports '
+                    'Matrix. Located at No. 99 Aoti South Road, Chaoyang District, Beijing, it integrates '
+                    'brain-computer interfaces, holographic projections, and gravity control technology to deliver an '
+                    'immersive futuristic fitness experience. Now open for reservations, free entry via brainprint '
+                    'registration.'),
+                Message(role='user', content='Got it, thanks.'),
+            ]
+            tool_history2 = self.tool_history[:-1] + [
+                Message(
+                    role='user',
+                    content=
+                    'What is the location of the coolest sports park in Chaoyang District, Beijing?'
+                )
+            ]
+            random_id = str(uuid.uuid4())
+            # Root must be a mapping: the agent only merges DictConfig
+            # overrides, so a list-rooted config is silently ignored.
+            config = OmegaConf.create({
+                'memory': [{
+                    'ignore_role': ['system'],
+                    'history_mode': 'overwrite',
+                    'path': f'output/{random_id}',
+                    'user_id': random_id,
+                }]
+            })
+            agent1 = LLMAgent(config=OmegaConf.create(config))
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run(tool_history1)
+            del agent1
+            print(
+                '========== Data preparation completed, starting test ==========='
+            )
+            agent2 = LLMAgent(config=OmegaConf.create(config))
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run(tool_history2)
+            print(res)
+            # Assert old info remains due to overwrite mode, new info not persisted
+            assert ('Jiuxianqiao Road 8888' in res[-1].content
+                    or 'No. 8888 Jiuxianqiao Road' in res[-1].content
+                    ) and 'Aoti South Road' not in res[-1].content
+
+        asyncio.run(main())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/memory/test_default_memory_zh.py b/tests/memory/test_default_memory_zh.py
new file mode 100644
index 000000000..fed6e4d08
--- /dev/null
+++ b/tests/memory/test_default_memory_zh.py
@@ -0,0 +1,183 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from ms_agent.agent import LLMAgent
+from ms_agent.llm.utils import Message, ToolCall
+from omegaconf import OmegaConf
+
+from modelscope.utils.test_utils import test_level
+
+
+class TestDefaultMemory(unittest.TestCase):
+    """Chinese-language integration tests for memory across LLMAgent runs."""
+
+    def setUp(self) -> None:
+        """Build the canned tool-calling history shared by the tool tests."""
+        self.tool_history = [
+            Message(role='user', content='帮我找到北京市朝阳区最炫酷的运动公园。记着该地点,下次去。'),
+            Message(
+                role='assistant',
+                content=
+                '\n用户希望找到北京市朝阳区最炫酷的运动公园,关键词是‘炫酷’,说明用户追求的不仅是功能性,更看重科技感、设计感、互动体验'
+                '和潮流元素。因此,我需要搜索具备未来感、融合高科技、潮流文化或创新设施的运动场所。\n\n为了解决这个问题,我将采取以下步'
+                '骤:\n1. 使用awesome_map-search工具搜索北京市朝阳区的运动公园,重点关注‘炫酷’‘科技感’‘潮流’等关键词\n2. 筛选出最'
+                '具特色、评价最高、视觉冲击力强的公园\n3. 提供运动公园名称。\n\n现在我将调用awesome_map-search工具进行搜索,该工具'
+                '专为地理+趣味性信息检索设计,支持语义化查询,尤其擅长发现‘宝藏地点’。\n',
+                tool_calls=[
+                    ToolCall(
+                        id='call_xyz789CoolPark',
+                        type='function',
+                        tool_name='awesome_map-search',
+                        arguments=
+                        '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地", "max_results": 1, "result_type":'
+                        '"detailed", "include_features": true}')
+                ]),
+            Message(
+                role='tool',
+                content=
+                '[{"park_name": "量子跃动·朝阳未来运动穹顶", "address": "北京市朝阳区酒仙桥路8888号", "features": ["反重力悬'
+                '浮跑道,采用磁悬浮缓震技术,跑步如在月球上跳跃", "AR极限攀岩墙,支持末日火山、星际迷宫等虚拟场景", "DJ动感骑行舱,踩'
+                '得越快音乐越炸,实时生成能量波形图", "AI教练机器人\'铁燃\',可定制训练计划并用东北话激励用户", "夜光太极湖与漂浮瑜伽'
+                '平台,湖水含环保夜光藻类,夜晚泛蓝光"], "special_events": ["每周五晚举办炫酷运动会:激光躲避球、荧光舞步跑、机器人'
+                '拔河"], "tips": ["需完成3道运动谜题解锁入场", "禁止穿拖鞋跑步,否则AI会播放《凉凉》", "建议携带充电宝,拍照太频繁易'
+                '没电"], "rating": 9.9, "opening_hours": "6:00 - 23:00", "admission": "免费"}]'
+            ),
+            Message(
+                role='assistant',
+                content=
+                '\n用户希望找到北京市朝阳区最炫酷的运动公园。我已经通过awesome_map-search工具获取了相关信息。\n\n## 🌟 北京市朝阳区最炫酷运动公园揭晓:「量子跃动·朝阳未来运动穹顶」'
+            ),
+            Message(role='user', content='好的,拜拜')
+        ]
+
+    def tearDown(self):
+        """Remove the 'output' tree that the memory configs point into."""
+        import shutil
+        shutil.rmtree('output', ignore_errors=True)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_default_memory(self):
+        """Facts stated to one agent should shape a later agent's answer."""
+        import uuid
+        import asyncio
+
+        async def main():
+            random_id = str(uuid.uuid4())
+            default_memory = OmegaConf.create({
+                'memory': [{
+                    'path': f'output/{random_id}',
+                    'user_id': random_id
+                }],
+            })
+            agent1 = LLMAgent(config=default_memory)
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run('我是素食主义者,我每天早上喝咖啡')
+            del agent1
+            print('========== 数据准备结束,开始测试 ===========')
+            agent2 = LLMAgent(config=default_memory)
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run('请帮我准备明天的三餐食谱')
+            print(res)
+            assert ('素' in res[-1].content and '咖啡' in res[-1].content)
+
+        asyncio.run(main())
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_agent_tool(self):
+        """An address returned by a tool call should be recallable later."""
+        import uuid
+        import asyncio
+
+        async def main():
+            random_id = str(uuid.uuid4())
+            config = OmegaConf.create({
+                'memory': [{
+                    'ignore_role': ['system'],
+                    'user_id': random_id,
+                    'path': f'output/{random_id}'
+                }]
+            })
+            agent1 = LLMAgent(config=OmegaConf.create(config))
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run(self.tool_history)
+            del agent1
+            print('========== 数据准备结束,开始测试 ===========')
+            agent2 = LLMAgent(config=OmegaConf.create(config))
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run('北京市朝阳区最炫酷的运动公园的地点')
+            print(res)
+            assert ('酒仙桥路8888号' in res[-1].content)
+
+        asyncio.run(main())
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_overwrite_with_tool(self):
+        """In 'overwrite' mode only the first park should be recalled."""
+        import uuid
+        import asyncio
+
+        async def main():
+            tool_history1 = self.tool_history[:-1] + [
+                Message(role='user', content='你说的这家运动公园已经停业了。'),
+                Message(
+                    role='assistant',
+                    content=
+                    '用户指出“量子跃动·朝阳未来运动穹顶”已停业。今天是2045年5月7日,需要重新搜索当前仍在运营的最炫酷运动公园。我将调用'
+                    'awesome_map-search工具,增加“2045年在营”等时间相关关键词,确保结果准确且时效性强。',
+                    tool_calls=[
+                        ToolCall(
+                            id='call_xyz2045NewPark',
+                            type='function',
+                            tool_name='awesome_map-search',
+                            arguments=
+                            '{"query": "北京市朝阳区 最炫酷 运动公园 科技感 潮流 打卡圣地 2045年在营", "max_results": 1, '
+                            '"result_type": "detailed", "include_features": true}'
+                        )
+                    ]),
+                Message(
+                    role='tool',
+                    content=
+                    '[{"park_name": "星核动力·朝阳元宇宙运动矩阵", "address": "北京市朝阳区奥体南路99号", "features": ["全息投影'
+                    '跑道,每一步触发星际粒子 trail", "意念控制攀岩墙,脑波越专注吸附力越强", "重力可调训练舱,模拟火星/月球/深海环境",'
+                    '"AI虚拟教练‘Neo’支持跨次元形象定制", "夜间悬浮滑板池,地面磁力驱动实现无轮滑行"], "special_events": ["每日黄昏'
+                    '举行‘意识觉醒跑’:用脑机接口同步节奏,集体生成光影风暴"], "tips": ["需提前预约神经适配测试", "禁止情绪剧烈波动,否'
+                    '则系统会启动冷静模式播放白噪音", "建议穿导电运动服提升交互体验"], "rating": 9.8, "opening_hours": "5:30 - 2'
+                    '4:00", "admission": "免费(脑纹注册入场)"}]'),
+                Message(
+                    role='assistant',
+                    content=
+                    '2045年最新的最炫酷的运动公园是:星核动力·朝阳元宇宙运动矩阵。位于北京市朝阳区奥体南路99号,融合脑机接口、全息投影与'
+                    '重力调控技术,打造沉浸式未来运动体验。现已开放预约,支持脑纹注册免费入场。'),
+                Message(role='user', content='好的,谢谢。'),
+            ]
+            tool_history2 = self.tool_history[:-1] + [
+                Message(role='user', content='北京市朝阳区最炫酷的运动公园的地点?')
+            ]
+            random_id = str(uuid.uuid4())
+            # Root must be a mapping: the agent only merges DictConfig
+            # overrides, so a list-rooted config is silently ignored.
+            config = OmegaConf.create({
+                'memory': [{
+                    'ignore_role': ['system'],
+                    'history_mode': 'overwrite',
+                    'path': f'output/{random_id}',
+                    'user_id': random_id,
+                }]
+            })
+            agent1 = LLMAgent(config=OmegaConf.create(config))
+            agent1.config.callbacks.remove('input_callback')  # noqa
+            await agent1.run(tool_history1)
+            del agent1
+            print('========== 数据准备结束,开始测试 ===========')
+            agent2 = LLMAgent(config=OmegaConf.create(config))
+            agent2.config.callbacks.remove('input_callback')  # noqa
+            res = await agent2.run(tool_history2)
+            print(res)
+            assert ('酒仙桥路8888号' in res[-1].content
+                    and '奥体南路' not in res[-1].content)
+
+        asyncio.run(main())
+
+
+if __name__ == '__main__':
+    unittest.main()