Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import time
import traceback

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

from fastapi import APIRouter, HTTPException
Expand Down Expand Up @@ -246,6 +248,7 @@ def search_memories(search_req: APISearchRequest):
@router.post("/add", summary="Add memories", response_model=MemoryResponse)
def add_memories(add_req: APIADDRequest):
"""Add memories for a specific user."""
init_time = time.time()
# Create UserContext object - how to assign values
user_context = UserContext(
user_id=add_req.user_id,
Expand All @@ -256,35 +259,59 @@ def add_memories(add_req: APIADDRequest):
target_session_id = add_req.session_id
if not target_session_id:
target_session_id = "default_session"
memories = mem_reader.get_memory(
[add_req.messages],
type="chat",
info={
"user_id": add_req.user_id,
"session_id": target_session_id,
},
)

# Flatten memory list
flattened_memories = [mm for m in memories for mm in m]
logger.info(f"Memory extraction completed for user {add_req.user_id}")
mem_id_list: list[str] = naive_mem_cube.text_mem.add(
flattened_memories,
user_name=user_context.mem_cube_id,
)
def process_mode(mode: str, scopes: list[str]):
try:
memories = mem_reader.get_memory(
[add_req.messages],
type="chat",
info={
"user_id": add_req.user_id,
"session_id": target_session_id,
},
mode=None if mode == "normal" else mode,
)
flattened = [m for group in memories for m in group]
logger.info(f"[{mode}] Extracted {len(flattened)} memories")

mem_id_list = naive_mem_cube.text_mem.add(
flattened,
user_name=user_context.mem_cube_id,
scope=scopes,
)
logger.info(
f"[{mode}] Added {len(mem_id_list)} memories {mem_id_list}"
f"in session {add_req.session_id}: {mem_id_list}"
)
logger.debug(
f"Time add for mode {mode} is "
f"{round(time.time() - init_time)}"
f"in session {add_req.session_id}: {mem_id_list}"
)

return list(zip(mem_id_list, flattened, strict=False))
except Exception as e:
logger.error(f"[{mode}] Failed: {e} {traceback.format_exc()}")
return []

with ThreadPoolExecutor(max_workers=2) as executor:
futures = {
executor.submit(process_mode, "normal", ["LongTermMemory", "UserMemory"]): "normal",
executor.submit(process_mode, "fast", ["WorkingMemory"]): "fast",
}
all_results = []
for future in as_completed(futures):
all_results.extend(future.result())

logger.info(
f"Added {len(mem_id_list)} memories for user {add_req.user_id} "
f"in session {add_req.session_id}: {mem_id_list}"
)
response_data = [
{
"memory": memory.memory,
"memory_id": memory_id,
"memory_type": memory.metadata.memory_type,
}
for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False)
for memory_id, memory in all_results
]
logger.info(f"[mem_add] All modes done: {len(response_data)} memories total")
return MemoryResponse(
message="Memory added successfully",
data=response_data,
Expand Down
9 changes: 7 additions & 2 deletions src/memos/memories/textual/simple_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def __init__(
logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}")

def add(
self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None
self,
memories: list[TextualMemoryItem | dict[str, Any]],
user_name: str | None = None,
scope=None,
) -> list[str]:
"""Add memories.
Args:
Expand All @@ -91,7 +94,9 @@ def add(
plan = plan_memory_operations(memory_items, metadata, self.graph_store)
execute_plan(memory_items, metadata, plan, self.graph_store)
"""
return self.memory_manager.add(memories, user_name=user_name)
if scope is None:
scope = ["LongTermMemory", "UserMemory", "WorkingMemory"]
return self.memory_manager.add(memories, user_name=user_name, scope=scope)

def replace_working_memory(
self, memories: list[TextualMemoryItem], user_name: str | None = None
Expand Down
8 changes: 6 additions & 2 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,16 @@ def __init__(self, config: TreeTextMemoryConfig):
else:
logger.info("No internet retriever configured")

def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]:
def add(
self, memories: list[TextualMemoryItem | dict[str, Any]], scope=None, **kwargs
) -> list[str]:
"""Add memories.
Args:
memories: List of TextualMemoryItem objects or dictionaries to add.
"""
return self.memory_manager.add(memories, mode=self.mode)
if scope is None:
scope = ["LongTermMemory", "UserMemory", "WorkingMemory"]
return self.memory_manager.add(memories, mode=self.mode, scope=scope)

def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
self.memory_manager.replace_working_memory(memories)
Expand Down
31 changes: 24 additions & 7 deletions src/memos/memories/textual/tree_text_memory/organize/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,24 @@ def __init__(
self._merged_threshold = merged_threshold

def add(
self, memories: list[TextualMemoryItem], user_name: str | None = None, mode: str = "sync"
self,
memories: list[TextualMemoryItem],
user_name: str | None = None,
mode: str = "sync",
scope: list | None = None,
) -> list[str]:
"""
Add new memories in parallel to different memory types.
"""
if scope is None:
scope = ["LongTermMemory", "UserMemory", "WorkingMemory"]

added_ids: list[str] = []

with ContextThreadPoolExecutor(max_workers=20) as executor:
futures = {executor.submit(self._process_memory, m, user_name): m for m in memories}
futures = {
executor.submit(self._process_memory, m, user_name, scope): m for m in memories
}
for future in as_completed(futures, timeout=60):
try:
ids = future.result()
Expand Down Expand Up @@ -127,19 +136,27 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None:
self.current_memory_size = {record["memory_type"]: record["count"] for record in results}
logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}")

def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None):
def _process_memory(
self, memory: TextualMemoryItem, user_name: str | None = None, scope: list | None = None
):
"""
Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory).
This method runs asynchronously to process each memory item.
"""
if scope is None:
scope = ["LongTermMemory", "UserMemory", "WorkingMemory"]
ids: list[str] = []
futures = []

with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex:
f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name)
futures.append(f_working)

if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"):
if "WorkingMemory" in scope:
f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name)
futures.append(f_working)

if (
memory.metadata.memory_type in ("LongTermMemory", "UserMemory")
and memory.metadata.memory_type in scope
):
f_graph = ex.submit(
self._add_to_graph_memory,
memory=memory,
Expand Down
10 changes: 7 additions & 3 deletions tests/memories/textual/test_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid

from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

import pytest

Expand Down Expand Up @@ -66,7 +66,9 @@ def test_add_calls_manager(mock_tree_text_memory):
metadata=TreeNodeTextualMemoryMetadata(updated_at=None),
)
mock_tree_text_memory.add([mock_item])
mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync")
mock_tree_text_memory.memory_manager.add.assert_called_once_with(
[mock_item], mode="sync", scope=ANY
)


def test_get_working_memory_sorted(mock_tree_text_memory):
Expand Down Expand Up @@ -161,4 +163,6 @@ def test_add_returns_ids(mock_tree_text_memory):
result = mock_tree_text_memory.add(mock_items)

assert result == dummy_ids
mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync")
mock_tree_text_memory.memory_manager.add.assert_called_once_with(
mock_items, mode="sync", scope=ANY
)