diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a332de583..5b9d32caf 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,6 +1,8 @@ import os +import time import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any from fastapi import APIRouter, HTTPException @@ -246,6 +248,7 @@ def search_memories(search_req: APISearchRequest): @router.post("/add", summary="Add memories", response_model=MemoryResponse) def add_memories(add_req: APIADDRequest): """Add memories for a specific user.""" + init_time = time.time() # Create UserContext object - how to assign values user_context = UserContext( user_id=add_req.user_id, @@ -256,35 +259,59 @@ def add_memories(add_req: APIADDRequest): target_session_id = add_req.session_id if not target_session_id: target_session_id = "default_session" - memories = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - # Flatten memory list - flattened_memories = [mm for m in memories for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_id_list: list[str] = naive_mem_cube.text_mem.add( - flattened_memories, - user_name=user_context.mem_cube_id, - ) + def process_mode(mode: str, scopes: list[str]): + try: + memories = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode=None if mode == "normal" else mode, + ) + flattened = [m for group in memories for m in group] + logger.info(f"[{mode}] Extracted {len(flattened)} memories") + + mem_id_list = naive_mem_cube.text_mem.add( + flattened, + user_name=user_context.mem_cube_id, + scope=scopes, + ) + logger.info( + f"[{mode}] Added {len(mem_id_list)} memories {mem_id_list}" + f"in session {add_req.session_id}: {mem_id_list}" + ) + logger.debug( + f"Time add for mode {mode} is " + f"{round(time.time() - init_time)}" + f"in session {add_req.session_id}: {mem_id_list}" + ) + + return list(zip(mem_id_list, flattened, strict=False)) + except Exception as e: + logger.error(f"[{mode}] Failed: {e} {traceback.format_exc()}") + return [] + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = { + executor.submit(process_mode, "normal", ["LongTermMemory", "UserMemory"]): "normal", + executor.submit(process_mode, "fast", ["WorkingMemory"]): "fast", + } + all_results = [] + for future in as_completed(futures): + all_results.extend(future.result()) - logger.info( - f"Added {len(mem_id_list)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_id_list}" - ) response_data = [ { "memory": memory.memory, "memory_id": memory_id, "memory_type": memory.metadata.memory_type, } - for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) + for memory_id, memory in all_results ] + logger.info(f"[mem_add] All modes done: {len(response_data)} memories total") return MemoryResponse( message="Memory added successfully", data=response_data, diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 9c67db288..e72968777 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -80,7 +80,10 @@ def __init__( logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") def add( - self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None + self, + memories: list[TextualMemoryItem | dict[str, Any]], + user_name: str | None = None, + scope=None, ) -> list[str]: """Add memories. Args: @@ -91,7 +94,9 @@ def add( plan = plan_memory_operations(memory_items, metadata, self.graph_store) execute_plan(memory_items, metadata, plan, self.graph_store) """ - return self.memory_manager.add(memories, user_name=user_name) + if scope is None: + scope = ["LongTermMemory", "UserMemory", "WorkingMemory"] + return self.memory_manager.add(memories, user_name=user_name, scope=scope) def replace_working_memory( self, memories: list[TextualMemoryItem], user_name: str | None = None diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index fccd83fa6..0970a6320 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -81,12 +81,16 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: + def add( + self, memories: list[TextualMemoryItem | dict[str, Any]], scope=None, **kwargs + ) -> list[str]: """Add memories. Args: memories: List of TextualMemoryItem objects or dictionaries to add. """ - return self.memory_manager.add(memories, mode=self.mode) + if scope is None: + scope = ["LongTermMemory", "UserMemory", "WorkingMemory"] + return self.memory_manager.add(memories, mode=self.mode, scope=scope) def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: self.memory_manager.replace_working_memory(memories) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 54776134b..05b2cda1d 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -53,15 +53,24 @@ def __init__( self._merged_threshold = merged_threshold def add( - self, memories: list[TextualMemoryItem], user_name: str | None = None, mode: str = "sync" + self, + memories: list[TextualMemoryItem], + user_name: str | None = None, + mode: str = "sync", + scope: list | None = None, ) -> list[str]: """ Add new memories in parallel to different memory types. """ + if scope is None: + scope = ["LongTermMemory", "UserMemory", "WorkingMemory"] + added_ids: list[str] = [] with ContextThreadPoolExecutor(max_workers=20) as executor: - futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} + futures = { + executor.submit(self._process_memory, m, user_name, scope): m for m in memories + } for future in as_completed(futures, timeout=60): try: ids = future.result() @@ -127,19 +136,27 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None: self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") - def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): + def _process_memory( + self, memory: TextualMemoryItem, user_name: str | None = None, scope: list | None = None + ): """ Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. """ + if scope is None: + scope = ["LongTermMemory", "UserMemory", "WorkingMemory"] ids: list[str] = [] futures = [] with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: - f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) - futures.append(f_working) - - if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): + if "WorkingMemory" in scope: + f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) + futures.append(f_working) + + if ( + memory.metadata.memory_type in ("LongTermMemory", "UserMemory") + and memory.metadata.memory_type in scope + ): f_graph = ex.submit( self._add_to_graph_memory, memory=memory, diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index 772a79d78..166bad633 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -1,6 +1,6 @@ import uuid -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest @@ -66,7 +66,9 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + [mock_item], mode="sync", scope=ANY + ) def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +163,6 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + mock_items, mode="sync", scope=ANY + )