Memory #721
Merged
Changes from 2 commits
Commits (20)
83e62b3  runable
7a82425  fix lint
e250d3c  update
123ca32  agent runable
dba0c39  test agent default memory runabl
d6910bd  support ignore_role
845f17d  minor fix
228ca13  feat: modify history messages
9f9ed65  minor fix
c1fb41e  fix typo
f91203c  fix time update
27d8502  fix comment & adjust for conficts in advance
9051d16  fix bugs
f1a92c6  minor fix
4a30dd8  add en test case
0a9ebdd  Merge branch 'modelscope:main' into memory
04f18cf  Merge remote-tracking branch 'origin' into memory (suluyana)
0b87e2a  fix conflicts
3b1343e  minor fix
fc3b724  Merge remote-tracking branch 'origin' into memory
@@ -0,0 +1,293 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
from typing import Any, Dict, List, Literal, Optional, Set, Tuple

from mem0 import Memory as Mem0Memory
from ms_agent.agent.memory import Memory
from ms_agent.llm.utils import Message
from ms_agent.utils.logger import logger
from ms_agent.utils.prompts import FACT_RETRIEVAL_PROMPT
from omegaconf import DictConfig, OmegaConf


class DefaultMemory(Memory):
    """The memory refine tool."""

    def __init__(self,
                 config: DictConfig,
                 cache_messages: Optional[List[Message]] = None,
                 conversation_id: Optional[str] = None,
                 persist: bool = False,
                 path: Optional[str] = None,
                 history_mode: Literal['add', 'overwrite'] = 'overwrite',
                 current_memory_cache_position: int = 0):
        super().__init__(config)
        self.cache_messages = cache_messages
        self.conversation_id: Optional[str] = conversation_id or getattr(
            config.memory, 'conversation_id', None)
        self.persist: Optional[bool] = persist or getattr(
            config.memory, 'persist', None)
        self.compress: Optional[bool] = getattr(config.memory, 'compress', None)
        self.embedder: Optional[str] = getattr(config.memory, 'embedder', None)
        self.is_retrieve: Optional[bool] = getattr(config.memory, 'is_retrieve', None)
        self.path: Optional[str] = path or getattr(config.memory, 'path', None)
        self.history_mode = history_mode or getattr(config.memory, 'history_mode', None)
        self.current_memory_cache_position = current_memory_cache_position
        self.memory = self._init_memory()
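
For reference, the constructor falls back to the `config.memory` section for any argument that is not passed explicitly; a minimal sketch of that section as a plain dict (field names taken from the `getattr` calls above, values are placeholders mirroring the `main()` example further down):

```python
# Placeholder values; explicit constructor arguments take precedence.
memory_section = {
    'conversation_id': 'default_id',
    'persist': True,
    'compress': True,        # extract facts with an LLM before storing (see _init_memory)
    'is_retrieve': True,     # when False, _init_memory() returns None
    'history_mode': 'add',   # or 'overwrite'
    'path': './output',      # on-disk location of the vector store
    'embedder': {'provider': 'openai', 'config': {'model': 'text-embedding-v4'}},
}
```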
    def _should_update_memory(self, messages: List[Message]) -> bool:
        return True

    def _find_messages_common_prefix(
            self,
            messages: List[Dict],
            ignore_role: Optional[Set[str]] = {'system'},
            ignore_fields: Optional[Set[str]] = {'reasoning_content'},
    ) -> Tuple[List[Dict], int, int]:
        """Compare the incoming messages with the cached messages and extract
        the longest common prefix.

        Args:
            messages: The current messages, a List[Dict] in the OpenAI API format.
            ignore_role: Roles to skip during the comparison, e.g. messages
                with role="system" or role="tool".
            ignore_fields: Optional set of field names to ignore during the
                comparison, e.g. {"reasoning_content"}.

        Returns:
            The longest common prefix (List[Dict]), together with the
            corresponding index into `messages` and into `self.cache_messages`.
        """
        if not messages or not isinstance(messages, list):
            return [], -1, -1

        if ignore_fields is None:
            ignore_fields = set()

        # Pre-processing: filter out messages whose role is in ignore_role.
        def _ignore_role(msgs):
            filtered = []
            indices = []  # original index of each filtered message
            for idx, msg in enumerate(msgs):
                if ignore_role and msg.get('role') in ignore_role:
                    continue
                filtered.append(msg)
                indices.append(idx)
            return filtered, indices

        filtered_messages, indices = _ignore_role(messages)
        filtered_cache_messages, cache_indices = _ignore_role(
            self.cache_messages)

        # Use the shorter length to avoid indexing out of bounds.
        min_length = min(
            len(msgs) for msgs in [filtered_messages, filtered_cache_messages])
        common_prefix = []

        idx = 0
        for idx in range(min_length):
            current_cache_msg = filtered_cache_messages[idx]
            current_msg = filtered_messages[idx]
            is_common = True

            # Compare all fields except the ignored ones.
            all_keys = set(current_cache_msg.keys()).union(
                set(current_msg.keys()))
            for key in all_keys:
                if key in ignore_fields:
                    continue
                if current_cache_msg.get(key) != current_msg.get(key):
                    is_common = False
                    break

            if not is_common:
                break

            # Keep a deep copy of the current message to preserve the original structure.
            common_prefix.append(deepcopy(current_msg))

        if len(common_prefix) == 0:
            return [], -1, -1

        return common_prefix, indices[idx], cache_indices[idx]
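
A minimal sketch of how the prefix comparison behaves; the message payloads are made up for illustration, and `mem` is assumed to be a `DefaultMemory` instance:

```python
cached = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'What is bun?'},
    {'role': 'assistant', 'content': 'Bun is a JavaScript runtime and toolkit.'},
]
incoming = cached + [{'role': 'user', 'content': 'Does it replace npm?'}]

mem.cache_messages = cached
prefix, msg_idx, cache_idx = mem._find_messages_common_prefix(incoming)
# The system message is skipped (ignore_role defaults to {'system'}), so the
# common prefix is the user/assistant pair shared by both lists; msg_idx and
# cache_idx are indices into the original, unfiltered lists.
```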
    def rollback(self, common_prefix_messages, cache_message_idx):
        # Support the retry mechanism: roll the memory back to the
        # `cache_message_idx`-th message of self.cache_messages.
        if self.history_mode == 'add':
            # Only the overwrite mode supports rollback, because rolling back
            # involves deleting entries.
            return
        # TODO: implement a real rollback instead of rebuilding from the prefix.
        self.memory.delete_all(user_id=self.conversation_id)
        self.memory.add(common_prefix_messages, user_id=self.conversation_id)

    def run(self, messages, ignore_role=None, ignore_fields=None):
        if not self.cache_messages:
            self.cache_messages = messages
        common_prefix_messages, messages_idx, cache_message_idx = \
            self._find_messages_common_prefix(
                messages, ignore_role=ignore_role, ignore_fields=ignore_fields)
        if not self.is_retrieve or not self._should_update_memory(messages):
            return messages
        if self.history_mode == 'add':
            self.memory.add(messages, user_id=self.conversation_id)
        else:
            if cache_message_idx < len(self.cache_messages):
                self.rollback(common_prefix_messages, cache_message_idx)
            self.cache_messages = messages
            self.memory.add(
                messages[messages_idx:], user_id=self.conversation_id)
        relevant_memories = self.memory.search(
            messages[-1]['content'], user_id=self.conversation_id, limit=3)
        memories_str = '\n'.join(f"- {entry['memory']}"
                                 for entry in relevant_memories['results'])
        logger.debug(f'Relevant memories:\n{memories_str}')
        # Drop the message span already covered by the memory and inject the
        # retrieved memories into the system prompt.
        if messages[0].get('role') == 'system':
            system_prompt = messages[0][
                'content'] + f'\nUser Memories: {memories_str}'
        else:
            system_prompt = ('\nYou are a helpful assistant. Answer the question '
                             f'based on query and memories.\nUser Memories: {memories_str}')
        new_messages = [{
            'role': 'system',
            'content': system_prompt
        }] + messages[messages_idx:]

        return new_messages
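
A hedged usage sketch (the instance and message content are illustrative, mirroring the `main()` example further down): `run()` stores the new turns, searches mem0 for up to three related memories, and returns a message list whose system prompt ends with a "User Memories:" block.

```python
# Assumes `memory` is a DefaultMemory configured with is_retrieve=True, as in main().
new_messages = memory.run(messages=[{
    'role': 'user',
    'content': 'Which package manager did we decide to use?'
}])
print(new_messages[0]['role'])     # 'system'
print(new_messages[0]['content'])  # system prompt + 'User Memories: - ...'
```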

    def _init_memory(self) -> Mem0Memory | None:
        if not self.is_retrieve:
            return None

        if self.embedder is None:
            # TODO: fall back to a default embedder
            raise ValueError('embedder must be set when is_retrieve=True.')
        embedder = self.embedder

        llm = {}
        if self.compress:
            llm_config = getattr(self.config.memory, 'llm', None)
            if llm_config is not None:
                # Follow the mem0 llm config format.
                model = llm_config.get('model')
                provider = llm_config.get('provider', 'openai')
                openai_base_url = llm_config.get('openai_base_url', None)
                openai_api_key = llm_config.get('api_key', None)
            else:
                llm_config = self.config.llm
                model = llm_config.model
                provider = llm_config.service
                openai_base_url = getattr(llm_config, f'{provider}_base_url', None)
                openai_api_key = getattr(llm_config, f'{provider}_api_key', None)

            llm = {
                'provider': provider,
                'config': {
                    'model': model,
                    'openai_base_url': openai_base_url,
                    'api_key': openai_api_key
                }
            }

        mem0_config = {
            'is_infer': self.compress,
            'llm': llm,
            'custom_fact_extraction_prompt':
                getattr(self.config.memory, 'fact_retrieval_prompt',
                        FACT_RETRIEVAL_PROMPT),
            'vector_store': {
                'provider': 'qdrant',
                'config': {
                    'path': self.path,
                    # 'on_disk': self.persist
                    'on_disk': True
                }
            },
            'embedder': embedder
        }
        # logger.info(f'Memory config: {mem0_config}')
        memory = Mem0Memory.from_config(mem0_config)
        if self.cache_messages:
            # Seed the store with the cached history, if any was provided.
            memory.add(self.cache_messages, user_id=self.conversation_id)
        return memory
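
The `embedder` section from `config.memory` is forwarded to mem0 unchanged, so it has to follow mem0's embedder config shape; a sketch assuming the OpenAI-compatible provider used in the example below (values are placeholders):

```python
embedder = {
    'provider': 'openai',
    'config': {
        'model': 'text-embedding-v4',
        'api_key': '<your api key>',
        'openai_base_url': 'https://dashscope.aliyuncs.com/compatible-mode/v1',
    }
}
```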

async def main():
    import os
    import json
    cfg = {
        'memory': {
            'conversation_id': 'default_id',
            'persist': True,
            'compress': True,
            'is_retrieve': True,
            'history_mode': 'add',
            # 'embedding_model': 'text-embedding-v4',
            'llm': {
                'provider': 'openai',
                'model': 'qwen3-235b-a22b-instruct-2507',
                'openai_base_url':
                    'https://dashscope.aliyuncs.com/compatible-mode/v1',
                'api_key': os.getenv('DASHSCOPE_API_KEY'),
            },
            'embedder': {
                'provider': 'openai',
                'config': {
                    'api_key': os.getenv('DASHSCOPE_API_KEY'),
                    'openai_base_url':
                        'https://dashscope.aliyuncs.com/compatible-mode/v1',
                    'model': 'text-embedding-v4',
                }
            }
            # 'vector_store': {
            #     'provider': 'qdrant',
            #     'config': {
            #         'path': '/Users/luyan/workspace/mem0/storage',
            #         'on_disk': False
            #     }
            # }
        }
    }
    with open('openai_format_test_case1.json', 'r') as f:
        data = json.load(f)
    config = OmegaConf.create(cfg)
    memory = DefaultMemory(
        config, path='./output', cache_messages=data, history_mode='add')
    res = memory.run(messages=[{
        'role': 'user',
        # "Will using bun have a big impact on a new project, and what new features does it offer?"
        'content': '使用bun会对新项目的影响大吗,有哪些新特性'
    }])
    print(res)


if __name__ == '__main__':
    import asyncio
    asyncio.run(main())
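
The script reads `openai_format_test_case1.json`, which is not part of this diff; a made-up fixture in the expected OpenAI message format might look like:

```python
# Hypothetical fixture content for openai_format_test_case1.json.
example_fixture = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'We are starting a new front-end project.'},
    {'role': 'assistant', 'content': 'Noted. Tell me about your tooling preferences.'},
]
```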
This file was deleted.
Add function comments for key functions in the form of:
"""
Prepare memory ...
Args:
...
Returns:
...
Raises: (Optional)
....
"""