diff --git a/examples/mem_api/pipeline_test.py b/examples/mem_api/pipeline_test.py new file mode 100644 index 000000000..cd7b3bee3 --- /dev/null +++ b/examples/mem_api/pipeline_test.py @@ -0,0 +1,178 @@ +""" +Pipeline test script for MemOS Server API functions. +This script directly tests add and search functionalities without going through the API layer. +If you want to start server_api set .env to MemOS/.env and run: +uvicorn memos.api.server_api:app --host 0.0.0.0 --port 8002 --workers 4 +""" + +from typing import Any + +from dotenv import load_dotenv + +# Import directly from server_router to reuse initialized components +from memos.api.routers.server_router import ( + _create_naive_mem_cube, + mem_reader, +) +from memos.log import get_logger + + +# Load environment variables +load_dotenv() + +logger = get_logger(__name__) + + +def test_add_memories( + messages: list[dict[str, str]], + user_id: str, + mem_cube_id: str, + session_id: str = "default_session", +) -> list[str]: + """ + Test adding memories to the system. 
+ + Args: + messages: List of message dictionaries with 'role' and 'content' + user_id: User identifier + mem_cube_id: Memory cube identifier + session_id: Session identifier + + Returns: + List of memory IDs that were added + """ + logger.info(f"Testing add memories for user: {user_id}, mem_cube: {mem_cube_id}") + + # Create NaiveMemCube using server_router function + naive_mem_cube = _create_naive_mem_cube() + + # Extract memories from messages using server_router's mem_reader + memories = mem_reader.get_memory( + [messages], + type="chat", + info={ + "user_id": user_id, + "session_id": session_id, + }, + ) + + # Flatten memory list + flattened_memories = [mm for m in memories for mm in m] + + # Add memories to the system + mem_id_list: list[str] = naive_mem_cube.text_mem.add( + flattened_memories, + user_name=mem_cube_id, + ) + + logger.info(f"Added {len(mem_id_list)} memories: {mem_id_list}") + + # Print details of added memories + for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False): + logger.info(f" - ID: {memory_id}") + logger.info(f" Memory: {memory.memory}") + logger.info(f" Type: {memory.metadata.memory_type}") + + return mem_id_list + + +def test_search_memories( + query: str, + user_id: str, + mem_cube_id: str, + session_id: str = "default_session", + top_k: int = 5, + mode: str = "fast", + internet_search: bool = False, + moscube: bool = False, + chat_history: list | None = None, +) -> list[Any]: + """ + Test searching memories from the system. 
+ + Args: + query: Search query text + user_id: User identifier + mem_cube_id: Memory cube identifier + session_id: Session identifier + top_k: Number of top results to return + mode: Search mode + internet_search: Whether to enable internet search + moscube: Whether to enable moscube search + chat_history: Chat history for context + + Returns: + List of search results + """ + + # Create NaiveMemCube using server_router function + naive_mem_cube = _create_naive_mem_cube() + + # Prepare search filter + search_filter = {"session_id": session_id} if session_id != "default_session" else None + + search_results = naive_mem_cube.text_mem.search( + query=query, + user_name=mem_cube_id, + top_k=top_k, + mode=mode, + manual_close_internet=not internet_search, + moscube=moscube, + search_filter=search_filter, + info={ + "user_id": user_id, + "session_id": session_id, + "chat_history": chat_history or [], + }, + ) + + # Print search results + for idx, result in enumerate(search_results, 1): + logger.info(f"\n Result {idx}:") + logger.info(f" ID: {result.id}") + logger.info(f" Memory: {result.memory}") + logger.info(f" Score: {getattr(result, 'score', 'N/A')}") + logger.info(f" Type: {result.metadata.memory_type}") + + return search_results + + +def main(): + # Test parameters + user_id = "test_user_123" + mem_cube_id = "test_cube_123" + session_id = "test_session_001" + + test_messages = [ + {"role": "user", "content": "Where should I go for Christmas?"}, + { + "role": "assistant", + "content": "There are many places to visit during Christmas, such as the Bund and Disneyland in Shanghai.", + }, + {"role": "user", "content": "What about New Year's Eve?"}, + { + "role": "assistant", + "content": "For New Year's Eve, you could visit Times Square in New York or watch fireworks at the Sydney Opera House.", + }, + ] + + memory_ids = test_add_memories( + messages=test_messages, user_id=user_id, mem_cube_id=mem_cube_id, session_id=session_id + ) + + logger.info(f"\nSuccessfully added 
{len(memory_ids)} memories!") + + search_queries = [ + "How to enjoy Christmas?", + "Where to celebrate New Year?", + "What are good places to visit during holidays?", + ] + + for query in search_queries: + logger.info("\n" + "-" * 80) + results = test_search_memories(query=query, user_id=user_id, mem_cube_id=mem_cube_id) + print(f"Query: '{query}' returned {len(results)} results") + + +if __name__ == "__main__": + main() diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 7e425415b..eb2d7aa6d 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field # Import message types from core types module -from memos.types import MessageDict +from memos.types import MessageDict, PermissionDict T = TypeVar("T") @@ -164,6 +164,39 @@ class SearchRequest(BaseRequest): session_id: str | None = Field(None, description="Session ID for soft-filtering memories") +class APISearchRequest(BaseRequest): + """Request model for searching memories.""" + + query: str = Field(..., description="Search query") + user_id: str = Field(None, description="User ID") + mem_cube_id: str | None = Field(None, description="Cube ID to search in") + mode: str = Field("fast", description="search mode fast or fine") + internet_search: bool = Field(False, description="Whether to use internet search") + moscube: bool = Field(False, description="Whether to use MemOSCube") + top_k: int = Field(10, description="Number of results to return") + chat_history: list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + +class APIADDRequest(BaseRequest): + """Request model for creating memories.""" + + user_id: str = Field(None, description="User ID") + mem_cube_id: str = Field(..., description="Cube 
ID") + messages: list[MessageDict] | None = Field(None, description="List of messages to store.") + memory_content: str | None = Field(None, description="Memory content to store") + doc_path: str | None = Field(None, description="Path to document to store") + source: str | None = Field(None, description="Source of the memory") + chat_history: list[MessageDict] | None = Field(None, description="Chat history") + session_id: str | None = Field(None, description="Session id") + operation: list[PermissionDict] | None = Field( + None, description="operation ids for multi cubes" + ) + + class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py new file mode 100644 index 000000000..1d398ff72 --- /dev/null +++ b/src/memos/api/routers/server_router.py @@ -0,0 +1,282 @@ +import os + +from typing import Any + +from fastapi import APIRouter + +from memos.api.config import APIConfig +from memos.api.product_models import ( + APIADDRequest, + APISearchRequest, + MemoryResponse, + SearchResponse, +) +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_reader.factory import MemReaderFactory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from 
memos.reranker.factory import RerankerFactory +from memos.types import MOSSearchResult, UserContext + + +logger = get_logger(__name__) + +router = APIRouter(prefix="/product", tags=["Server API"]) + + +def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + """Build graph database configuration.""" + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def _build_llm_config() -> dict[str, Any]: + """Build LLM configuration.""" + return LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + +def _build_embedder_config() -> dict[str, Any]: + """Build embedder configuration.""" + return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + +def _build_mem_reader_config() -> dict[str, Any]: + """Build memory reader configuration.""" + return MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] + ) + + +def _build_reranker_config() -> dict[str, Any]: + """Build reranker configuration.""" + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def _build_internet_retriever_config() -> dict[str, Any]: + """Build internet retriever configuration.""" + return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) + + +def _get_default_memory_size(cube_config) -> dict[str, int]: + """Get default memory size configuration.""" + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + +def init_server(): + """Initialize 
server components and configurations.""" + # Get default cube configuration + default_cube_config = APIConfig.get_default_cube_config() + + # Build component configurations + graph_db_config = _build_graph_db_config() + print(graph_db_config) + llm_config = _build_llm_config() + embedder_config = _build_embedder_config() + mem_reader_config = _build_mem_reader_config() + reranker_config = _build_reranker_config() + internet_retriever_config = _build_internet_retriever_config() + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder + ) + + # Initialize memory manager + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + return ( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, + ) + + +# Initialize global components +( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, +) = init_server() + + +def _create_naive_mem_cube() -> NaiveMemCube: + """Create a NaiveMemCube instance with initialized components.""" + naive_mem_cube = NaiveMemCube( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + internet_retriever=internet_retriever, + memory_manager=memory_manager, + default_cube_config=default_cube_config, + ) + return naive_mem_cube + + +def _format_memory_item(memory_data: Any) -> dict[str, Any]: + """Format a single memory item for API response.""" + memory = 
memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + +@router.post("/search", summary="Search memories", response_model=SearchResponse) +def search_memories(search_req: APISearchRequest): + """Search memories for a specific user.""" + # Create UserContext object - how to assign values + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + logger.info(f"Search user_id is: {user_context.mem_cube_id}") + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + } + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + naive_mem_cube = _create_naive_mem_cube() + search_results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=search_req.mode, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + formatted_memories = [_format_memory_item(data) for data in search_results] + + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": formatted_memories, + } + ) + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + +@router.post("/add", summary="Add memories", response_model=MemoryResponse) +def 
add_memories(add_req: APIADDRequest): + """Add memories for a specific user.""" + # Create UserContext object - how to assign values + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=add_req.mem_cube_id, + session_id=add_req.session_id or "default_session", + ) + naive_mem_cube = _create_naive_mem_cube() + target_session_id = add_req.session_id + if not target_session_id: + target_session_id = "default_session" + memories = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + + # Flatten memory list + flattened_memories = [mm for m in memories for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + mem_id_list: list[str] = naive_mem_cube.text_mem.add( + flattened_memories, + user_name=user_context.mem_cube_id, + ) + + logger.info( + f"Added {len(mem_id_list)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_id_list}" + ) + response_data = [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) + ] + return MemoryResponse( + message="Memory added successfully", + data=response_data, + ) diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py new file mode 100644 index 000000000..78e05ef85 --- /dev/null +++ b/src/memos/api/server_api.py @@ -0,0 +1,38 @@ +import logging + +from fastapi import FastAPI + +from memos.api.exceptions import APIExceptionHandler +from memos.api.middleware.request_context import RequestContextMiddleware +from memos.api.routers.server_router import router as server_router + + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +app = FastAPI( + title="MemOS Product REST APIs", + description="A REST API for managing multiple 
users with MemOS Product.", + version="1.0.1", +) + +app.add_middleware(RequestContextMiddleware) +# Include routers +app.include_router(server_router) + +# Exception handlers +app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) + + +if __name__ == "__main__": + import argparse + + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=8001) + parser.add_argument("--workers", type=int, default=1) + args = parser.parse_args() + uvicorn.run("memos.api.server_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/configs/mem_user.py b/src/memos/configs/mem_user.py index 3ff1066e5..6e1ca4206 100644 --- a/src/memos/configs/mem_user.py +++ b/src/memos/configs/mem_user.py @@ -31,6 +31,17 @@ class MySQLUserManagerConfig(BaseUserManagerConfig): charset: str = Field(default="utf8mb4", description="MySQL charset") +class RedisUserManagerConfig(BaseUserManagerConfig): + """Redis user manager configuration.""" + + host: str = Field(default="localhost", description="Redis server host") + port: int = Field(default=6379, description="Redis server port") + username: str = Field(default="root", description="Redis username") + password: str = Field(default="", description="Redis password") + database: str = Field(default="memos_users", description="Redis database name") + charset: str = Field(default="utf8mb4", description="Redis charset") + + class UserManagerConfigFactory(BaseModel): """Factory for user manager configurations.""" @@ -42,6 +53,7 @@ class UserManagerConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "sqlite": SQLiteUserManagerConfig, "mysql": MySQLUserManagerConfig, + "redis": RedisUserManagerConfig, } @field_validator("backend") diff --git a/src/memos/configs/memory.py b/src/memos/configs/memory.py index 1eea6deaf..237450e15 100644 --- 
a/src/memos/configs/memory.py +++ b/src/memos/configs/memory.py @@ -180,6 +180,10 @@ class TreeTextMemoryConfig(BaseTextMemoryConfig): ) +class SimpleTreeTextMemoryConfig(TreeTextMemoryConfig): + """Simple tree text memory configuration class.""" + + # ─── 3. Global Memory Config Factory ────────────────────────────────────────── @@ -192,6 +196,7 @@ class MemoryConfigFactory(BaseConfig): backend_to_class: ClassVar[dict[str, Any]] = { "naive_text": NaiveTextMemoryConfig, "general_text": GeneralTextMemoryConfig, + "simple_tree_text": SimpleTreeTextMemoryConfig, "tree_text": TreeTextMemoryConfig, "kv_cache": KVCacheMemoryConfig, "vllm_kv_cache": KVCacheMemoryConfig, # Use same config as kv_cache diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..10c3c75d0 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -129,7 +129,6 @@ def _make_client_key(cfg: NebulaGraphDBConfig) -> str: "nebula-sync", ",".join(hosts), str(getattr(cfg, "user", "")), - str(getattr(cfg, "use_multi_db", False)), str(getattr(cfg, "space", "")), ] ) @@ -139,7 +138,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " tmp = object.__new__(NebulaGraphDB) tmp.config = cfg tmp.db_name = cfg.space - tmp.user_name = getattr(cfg, "user_name", None) + tmp.user_name = None tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) tmp.default_memory_dimension = 3072 tmp.common_fields = { @@ -169,7 +168,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -417,7 +416,9 @@ def create_index( self._create_basic_property_indexes() @timed - def remove_oldest_memory(self, memory_type: str, keep_latest: int) 
-> None: + def remove_oldest_memory( + self, memory_type: str, keep_latest: int, user_name: str | None = None + ) -> None: """ Remove all WorkingMemory nodes except the latest `keep_latest` entries. @@ -426,9 +427,10 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: keep_latest (int): Number of latest WorkingMemory entries to keep. """ optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + user_name = user_name if user_name else self.config.user_name + + optional_condition = f"AND n.user_name = '{user_name}'" query = f""" MATCH (n@Memory) WHERE n.memory_type = '{memory_type}' @@ -440,13 +442,13 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: self.execute_query(query) @timed - def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: + def add_node( + self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None + ) -> None: """ Insert or update a Memory node in NebulaGraph. 
""" - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name if user_name else self.config.user_name now = datetime.utcnow() metadata = metadata.copy() metadata.setdefault("created_at", now) @@ -475,11 +477,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: ) @timed - def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + def node_not_exist(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' query = f""" MATCH (n@Memory) WHERE {filter_clause} @@ -495,10 +495,11 @@ def node_not_exist(self, scope: str) -> int: raise @timed - def update_node(self, id: str, fields: dict[str, Any]) -> None: + def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: """ Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. """ + user_name = user_name if user_name else self.config.user_name fields = fields.copy() set_clauses = [] for k, v in fields.items(): @@ -509,45 +510,41 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @timed - def delete_node(self, id: str) -> None: + def delete_node(self, id: str, user_name: str | None = None) -> None: """ Delete a node from the graph. Args: id: Node identifier to delete. 
+ user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" - MATCH (n@Memory {{id: "{id}"}}) + MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} + DETACH DELETE n """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" - query += "\n DETACH DELETE n" self.execute_query(query) @timed - def add_edge(self, source_id: str, target_id: str, type: str): + def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): """ Create an edge from source node to target node. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). + user_name (str, optional): User name for filtering in non-multi-db mode """ if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") - + user_name = user_name if user_name else self.config.user_name props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' - + props = f'{{user_name: "{user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT (a) -[e@{type} {props}]-> (b) @@ -558,35 +555,35 @@ def add_edge(self, source_id: str, target_id: str, type: str): logger.error(f"Failed to insert edge: {e}", exc_info=True) @timed - def delete_edge(self, source_id: str, target_id: str, type: str) -> None: + def delete_edge( + self, source_id: str, target_id: str, type: str, user_name: str | None = None + ) -> None: """ Delete a specific edge between two nodes. Args: source_id: ID of the source node. target_id: ID of the target node. type: Relationship type to remove. 
+ user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (a@Memory) -[r@{type}]-> (b@Memory) WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @timed - def get_memory_count(self, memory_type: str) -> int: + def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -597,14 +594,13 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: + def count_nodes(self, scope: str, user_name: str | None = None) -> int: + user_name = user_name if user_name else self.config.user_name query = f""" MATCH (n@Memory) WHERE n.memory_type = "{scope}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -612,7 +608,12 @@ def count_nodes(self, scope: str) -> int: @timed def edge_exists( - self, source_id: str, target_id: str, type: str = "ANY", direction: str = "OUTGOING" + self, + source_id: str, + 
target_id: str, + type: str = "ANY", + direction: str = "OUTGOING", + user_name: str | None = None, ) -> bool: """ Check if an edge exists between two nodes. @@ -622,10 +623,12 @@ def edge_exists( type: Relationship type. Use "ANY" to match any relationship type. direction: Direction of the edge. Use "OUTGOING" (default), "INCOMING", or "ANY". + user_name (str, optional): User name for filtering in non-multi-db mode Returns: True if the edge exists, otherwise False. """ # Prepare the relationship pattern + user_name = user_name if user_name else self.config.user_name rel = "r" if type == "ANY" else f"r@{type}" # Prepare the match pattern with direction @@ -640,9 +643,7 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." ) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -654,22 +655,22 @@ def edge_exists( @timed # Graph Query & Reasoning - def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | None: + def get_node( + self, id: str, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any] | None: """ Retrieve a Memory node by its unique ID. Args: id (str): Node ID (Memory.id) include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: dict: Node properties as key-value pairs, or None if not found. 
""" - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' - + user_name = user_name if user_name else self.config.user_name + filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) @@ -692,13 +693,18 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | @timed def get_nodes( - self, ids: list[str], include_embedding: bool = False, **kwargs + self, + ids: list[str], + include_embedding: bool = False, + user_name: str | None = None, + **kwargs, ) -> list[dict[str, Any]]: """ Retrieve the metadata and memory of a list of nodes. Args: ids: List of Node identifier. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. 
@@ -709,19 +715,14 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" - + user_name = user_name if user_name else self.config.user_name + where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.id IN [{id_list}] {where_user} RETURN {return_fields} """ @@ -738,7 +739,9 @@ def get_nodes( return nodes @timed - def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[dict[str, str]]: + def get_edges( + self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None + ) -> list[dict[str, str]]: """ Get edges connected to a node, with optional type and direction filter. @@ -746,6 +749,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ id: Node ID to retrieve edges for. type: Relationship type to match, or 'ANY' to match all. direction: 'OUTGOING', 'INCOMING', or 'ANY'. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of edges: @@ -756,7 +760,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ """ # Build relationship type filter rel_type = "" if type == "ANY" else f"@{type}" - + user_name = user_name if user_name else self.config.user_name # Build Cypher pattern based on direction if direction == "OUTGOING": pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" @@ -770,8 +774,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. 
Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query = f""" MATCH {pattern} @@ -799,6 +802,7 @@ def get_neighbors_by_tag( top_k: int = 5, min_overlap: int = 1, include_embedding: bool = False, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Find top-K neighbor nodes with maximum tag overlap. @@ -809,13 +813,14 @@ def get_neighbors_by_tag( top_k: Max number of neighbors to return. min_overlap: Minimum number of overlapping tags required. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: List of dicts with node details and overlap count. """ if not tags: return [] - + user_name = user_name if user_name else self.config.user_name where_clauses = [ 'n.status = "activated"', 'NOT (n.node_type = "reasoning")', @@ -824,8 +829,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -859,12 +863,11 @@ def get_neighbors_by_tag( return result @timed - def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + def get_children_with_embeddings( + self, id: str, user_name: str | None = None + ) -> list[dict[str, Any]]: + user_name = user_name if user_name else self.config.user_name + 
where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -884,7 +887,11 @@ def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: @timed def get_subgraph( - self, center_id: str, depth: int = 2, center_status: str = "activated" + self, + center_id: str, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Retrieve a local subgraph centered at a given node. @@ -892,6 +899,7 @@ def get_subgraph( center_id: The ID of the center node. depth: The hop distance for neighbors. center_status: Required status for center node. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { "core_node": {...}, @@ -902,7 +910,8 @@ def get_subgraph( if not 1 <= depth <= 5: raise ValueError("depth must be 1-5") - user_name = self.config.user_name + user_name = user_name if user_name else self.config.user_name + gql = f""" MATCH (center@Memory) WHERE center.id = '{center_id}' @@ -954,6 +963,7 @@ def search_by_embedding( status: str | None = None, threshold: float | None = None, search_filter: dict | None = None, + user_name: str | None = None, **kwargs, ) -> list[dict]: """ @@ -968,6 +978,7 @@ def search_by_embedding( threshold (float, optional): Minimum similarity score threshold (0 ~ 1). search_filter (dict, optional): Additional metadata filters for search results. Keys should match node properties, values are the expected values. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. @@ -981,42 +992,35 @@ def search_by_embedding( - Typical use case: restrict to 'status = activated' to avoid matching archived or merged nodes. 
""" + user_name = user_name if user_name else self.config.user_name vector = _normalize(vector) dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE - LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + {where_clause} + ORDER BY inner_product(n.{self.dim_field}, a) DESC + LIMIT {top_k} + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1038,7 +1042,9 @@ def search_by_embedding( return [] @timed - def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: + def get_by_metadata( + self, filters: list[dict[str, Any]], 
user_name: str | None = None + ) -> list[str]: """ 1. ADD logic: "AND" vs "OR"(support logic combination); 2. Support nested conditional expressions; @@ -1054,6 +1060,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: {"field": "tags", "op": "contains", "value": "AI"}, ... ] + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[str]: Node IDs whose metadata match the filter conditions. (AND logic). @@ -1063,7 +1070,7 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: - Can be used for faceted recall or prefiltering before embedding rerank. """ where_clauses = [] - + user_name = user_name if user_name else self.config.user_name for _i, f in enumerate(filters): field = f["field"] op = f.get("op", "=") @@ -1087,11 +1094,10 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{user_name}"') where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" + gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" ids = [] try: result = self.execute_query(gql) @@ -1106,6 +1112,7 @@ def get_grouped_counts( group_fields: list[str], where_clause: str = "", params: dict[str, Any] | None = None, + user_name: str | None = None, ) -> list[dict[str, Any]]: """ Count nodes grouped by any fields. @@ -1115,24 +1122,24 @@ def get_grouped_counts( where_clause (str, optional): Extra WHERE condition. E.g., "WHERE n.status = 'activated'" params (dict, optional): Parameters for WHERE clause. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] 
""" if not group_fields: raise ValueError("group_fields cannot be empty") - - # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + user_name = user_name if user_name else self.config.user_name + # GQL-specific modifications + user_clause = f"n.user_name = '{user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1170,16 +1177,16 @@ def get_grouped_counts( return output @timed - def clear(self) -> None: + def clear(self, user_name: str | None = None) -> None: """ Clear the entire graph if the target database exists. 
+ + Args: + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" - + query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1187,11 +1194,14 @@ def clear(self) -> None: logger.error(f"[ERROR] Failed to clear database: {e}") @timed - def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: + def export_graph( + self, include_embedding: bool = False, user_name: str | None = None + ) -> dict[str, Any]: """ Export all graph nodes and edges in a structured form. Args: include_embedding (bool): Whether to include the large embedding field. + user_name (str, optional): User name for filtering in non-multi-db mode Returns: { @@ -1199,13 +1209,11 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: "edges": [ { "source": ..., "target": ..., "type": ... }, ... 
] } """ + user_name = user_name if user_name else self.config.user_name node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + node_query += f' WHERE n.user_name = "{user_name}"' + edge_query += f' WHERE r.user_name = "{user_name}"' try: if include_embedding: @@ -1265,20 +1273,19 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: return {"nodes": nodes, "edges": edges} @timed - def import_graph(self, data: dict[str, Any]) -> None: + def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: """ Import the entire graph from a serialized dictionary. Args: data: A dictionary containing all nodes and edges to be loaded. + user_name (str, optional): User name for filtering in non-multi-db mode """ + user_name = user_name if user_name else self.config.user_name for node in data.get("nodes", []): try: id, memory, metadata = _compose_node(node) - - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name - + metadata["user_name"] = user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) properties = ", ".join( @@ -1293,9 +1300,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1305,29 +1310,31 @@ def import_graph(self, data: dict[str, Any]) -> None: logger.error(f"Fail to load edge: {edge}, error: 
{e}") @timed - def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> (list)[dict]: + def get_all_memory_items( + self, scope: str, include_embedding: bool = False, user_name: str | None = None + ) -> (list)[dict]: """ Retrieve all memory items of a specific memory_type. Args: scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. include_embedding: with/without embedding + user_name (str, optional): User name for filtering in non-multi-db mode Returns: list[dict]: Full list of memory items under this scope. """ + user_name = user_name if user_name else self.config.user_name if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{user_name}'" return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {return_fields} LIMIT 100 @@ -1344,20 +1351,19 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( @timed def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False + self, scope: str, include_embedding: bool = False, user_name: str | None = None ) -> list[dict]: """ Find nodes that are likely candidates for structure optimization: - Isolated nodes, nodes with empty background, or nodes with exactly one child. - Plus: the child of any parent node that has exactly one child. 
""" - + user_name = user_name if user_name else self.config.user_name where_clause = f''' n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" @@ -1386,21 +1392,6 @@ def get_structure_optimization_candidates( logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") return candidates - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - @timed def detect_conflicts(self) -> list[tuple[str, str]]: """ @@ -1585,9 +1576,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. 
""" - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}" diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py new file mode 100644 index 000000000..7ce3ca642 --- /dev/null +++ b/src/memos/mem_cube/navie.py @@ -0,0 +1,166 @@ +import os + +from typing import Literal + +from memos.configs.mem_cube import GeneralMemCubeConfig +from memos.configs.utils import get_json_file_model_schema +from memos.embedders.base import BaseEmbedder +from memos.exceptions import ConfigurationError, MemCubeError +from memos.graph_dbs.base import BaseGraphDB +from memos.llms.base import BaseLLM +from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube +from memos.mem_reader.base import BaseMemReader +from memos.memories.activation.base import BaseActMemory +from memos.memories.parametric.base import BaseParaMemory +from memos.memories.textual.base import BaseTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.reranker.base import BaseReranker + + +logger = get_logger(__name__) + + +class NaiveMemCube(BaseMemCube): + """MemCube is a box for loading and dumping three types of memories.""" + + def __init__( + self, + llm: BaseLLM, + embedder: BaseEmbedder, + mem_reader: BaseMemReader, + graph_db: BaseGraphDB, + reranker: BaseReranker, + memory_manager: MemoryManager, + default_cube_config: GeneralMemCubeConfig, + internet_retriever: None = None, + ): + """Initialize the MemCube with a configuration.""" + self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( + llm, + embedder, + mem_reader, + graph_db, + reranker, + memory_manager, + default_cube_config.text_mem.config, + internet_retriever, + ) + self._act_mem: 
BaseActMemory | None = None + self._para_mem: BaseParaMemory | None = None + + def load( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Load memories. + Args: + dir (str): The directory containing the memory files. + memory_types (list[str], optional): List of memory types to load. + If None, loads all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) + if loaded_schema != self.config.model_schema: + raise ConfigurationError( + f"Configuration schema mismatch. Expected {self.config.model_schema}, " + f"but found {loaded_schema}." + ) + + # If no specific memory types specified, load all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Load specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.load(dir) + logger.debug(f"Loaded text_mem from {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.load(dir) + logger.info(f"Loaded act_mem from {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.load(dir) + logger.info(f"Loaded para_mem from {dir}") + + logger.info(f"MemCube loaded successfully from {dir} (types: {memory_types})") + + def dump( + self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + ) -> None: + """Dump memories. + Args: + dir (str): The directory where the memory files will be saved. + memory_types (list[str], optional): List of memory types to dump. + If None, dumps all available memory types. + Options: ["text_mem", "act_mem", "para_mem"] + """ + if os.path.exists(dir) and os.listdir(dir): + raise MemCubeError( + f"Directory {dir} is not empty. Please provide an empty directory for dumping." 
+ ) + + # Always dump config + self.config.to_json_file(os.path.join(dir, self.config.config_filename)) + + # If no specific memory types specified, dump all + if memory_types is None: + memory_types = ["text_mem", "act_mem", "para_mem"] + + # Dump specified memory types + if "text_mem" in memory_types and self.text_mem: + self.text_mem.dump(dir) + logger.info(f"Dumped text_mem to {dir}") + + if "act_mem" in memory_types and self.act_mem: + self.act_mem.dump(dir) + logger.info(f"Dumped act_mem to {dir}") + + if "para_mem" in memory_types and self.para_mem: + self.para_mem.dump(dir) + logger.info(f"Dumped para_mem to {dir}") + + logger.info(f"MemCube dumped successfully to {dir} (types: {memory_types})") + + @property + def text_mem(self) -> "BaseTextMemory | None": + """Get the textual memory.""" + if self._text_mem is None: + logger.warning("Textual memory is not initialized. Returning None.") + return self._text_mem + + @text_mem.setter + def text_mem(self, value: BaseTextMemory) -> None: + """Set the textual memory.""" + if not isinstance(value, BaseTextMemory): + raise TypeError(f"Expected BaseTextMemory, got {type(value).__name__}") + self._text_mem = value + + @property + def act_mem(self) -> "BaseActMemory | None": + """Get the activation memory.""" + if self._act_mem is None: + logger.warning("Activation memory is not initialized. Returning None.") + return self._act_mem + + @act_mem.setter + def act_mem(self, value: BaseActMemory) -> None: + """Set the activation memory.""" + if not isinstance(value, BaseActMemory): + raise TypeError(f"Expected BaseActMemory, got {type(value).__name__}") + self._act_mem = value + + @property + def para_mem(self) -> "BaseParaMemory | None": + """Get the parametric memory.""" + if self._para_mem is None: + logger.warning("Parametric memory is not initialized. 
Returning None.") + return self._para_mem + + @para_mem.setter + def para_mem(self, value: BaseParaMemory) -> None: + """Set the parametric memory.""" + if not isinstance(value, BaseParaMemory): + raise TypeError(f"Expected BaseParaMemory, got {type(value).__name__}") + self._para_mem = value diff --git a/src/memos/mem_user/persistent_factory.py b/src/memos/mem_user/persistent_factory.py index b5ece61b5..6a7b4fa13 100644 --- a/src/memos/mem_user/persistent_factory.py +++ b/src/memos/mem_user/persistent_factory.py @@ -3,6 +3,7 @@ from memos.configs.mem_user import UserManagerConfigFactory from memos.mem_user.mysql_persistent_user_manager import MySQLPersistentUserManager from memos.mem_user.persistent_user_manager import PersistentUserManager +from memos.mem_user.redis_persistent_user_manager import RedisPersistentUserManager class PersistentUserManagerFactory: @@ -11,6 +12,7 @@ class PersistentUserManagerFactory: backend_to_class: ClassVar[dict[str, Any]] = { "sqlite": PersistentUserManager, "mysql": MySQLPersistentUserManager, + "redis": RedisPersistentUserManager, } @classmethod diff --git a/src/memos/mem_user/redis_persistent_user_manager.py b/src/memos/mem_user/redis_persistent_user_manager.py new file mode 100644 index 000000000..48c89c663 --- /dev/null +++ b/src/memos/mem_user/redis_persistent_user_manager.py @@ -0,0 +1,225 @@ +"""Redis-based persistent user management system for MemOS with configuration storage. + +This module provides persistent storage for user configurations using Redis. 
+""" + +import json + +from memos.configs.mem_os import MOSConfig +from memos.dependency import require_python_package +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class RedisPersistentUserManager: + """Redis-based user configuration manager with persistence.""" + + @require_python_package( + import_name="redis", + install_command="pip install redis", + install_link="https://redis.readthedocs.io/en/stable/", + ) + def __init__( + self, + host: str = "localhost", + port: int = 6379, + password: str = "", + db: int = 0, + decode_responses: bool = True, + ): + """Initialize the Redis persistent user manager. + + Args: + user_id (str, optional): User ID. Defaults to "root". + host (str): Redis server host. Defaults to "localhost". + port (int): Redis server port. Defaults to 6379. + password (str): Redis password. Defaults to "". + db (int): Redis database number. Defaults to 0. + decode_responses (bool): Whether to decode responses to strings. Defaults to True. + """ + import redis + + self.host = host + self.port = port + self.db = db + + try: + # Create Redis connection + self._redis_client = redis.Redis( + host=host, + port=port, + password=password if password else None, + db=db, + decode_responses=decode_responses, + ) + + # Test connection + if not self._redis_client.ping(): + raise ConnectionError("Redis connection failed") + + logger.info( + f"RedisPersistentUserManager initialized successfully, connected to {host}:{port}/{db}" + ) + + except Exception as e: + logger.error(f"Redis connection error: {e}") + raise + + def _get_config_key(self, user_id: str) -> str: + """Generate Redis key for user configuration. + + Args: + user_id (str): User ID. + + Returns: + str: Redis key name. + """ + return user_id + + def save_user_config(self, user_id: str, config: MOSConfig) -> bool: + """Save user configuration to Redis. + + Args: + user_id (str): User ID. + config (MOSConfig): User's MOS configuration. 
+ + Returns: + bool: True if successful, False otherwise. + """ + try: + # Convert config to JSON string + config_dict = config.model_dump(mode="json") + config_json = json.dumps(config_dict, ensure_ascii=False, indent=2) + + # Save to Redis + key = self._get_config_key(user_id) + self._redis_client.set(key, config_json) + + logger.info(f"Successfully saved configuration for user {user_id} to Redis") + return True + + except Exception as e: + logger.error(f"Error saving configuration for user {user_id}: {e}") + return False + + def get_user_config(self, user_id: str) -> dict | None: + """Get user configuration from Redis (search interface). + + Args: + user_id (str): User ID. + + Returns: + MOSConfig | None: User's configuration object, or None if not found. + """ + try: + # Get configuration from Redis + key = self._get_config_key(user_id) + config_json = self._redis_client.get(key) + + if config_json is None: + logger.info(f"Configuration for user {user_id} does not exist") + return None + + # Parse JSON and create MOSConfig object + config_dict = json.loads(config_json) + + logger.info(f"Successfully retrieved configuration for user {user_id}") + return config_dict + + except json.JSONDecodeError as e: + logger.error(f"Error parsing JSON configuration for user {user_id}: {e}") + return None + except Exception as e: + logger.error(f"Error retrieving configuration for user {user_id}: {e}") + return None + + def delete_user_config(self, user_id: str) -> bool: + """Delete user configuration from Redis. + + Args: + user_id (str): User ID. + + Returns: + bool: True if successful, False otherwise. 
+ """ + try: + key = self._get_config_key(user_id) + result = self._redis_client.delete(key) + + if result > 0: + logger.info(f"Successfully deleted configuration for user {user_id}") + return True + else: + logger.warning(f"Configuration for user {user_id} does not exist, cannot delete") + return False + + except Exception as e: + logger.error(f"Error deleting configuration for user {user_id}: {e}") + return False + + def exists_user_config(self, user_id: str) -> bool: + """Check if user configuration exists. + + Args: + user_id (str): User ID. + + Returns: + bool: True if exists, False otherwise. + """ + try: + key = self._get_config_key(user_id) + return self._redis_client.exists(key) > 0 + except Exception as e: + logger.error(f"Error checking if configuration exists for user {user_id}: {e}") + return False + + def list_user_configs( + self, pattern: str = "user_config:*", count: int = 100 + ) -> dict[str, dict]: + """List all user configurations. + + Args: + pattern (str): Redis key matching pattern. Defaults to "user_config:*". + count (int): Number of keys to return per scan. Defaults to 100. + + Returns: + dict[str, dict]: Dictionary mapping user_id to dict objects. + """ + result = {} + try: + # Use SCAN command to iterate through all matching keys + cursor = 0 + while True: + cursor, keys = self._redis_client.scan(cursor, match=pattern, count=count) + + for key in keys: + # Extract user_id (remove "user_config:" prefix) + user_id = key.replace("user_config:", "") + config = self.get_user_config(user_id) + if config: + result[user_id] = config + + if cursor == 0: + break + + logger.info(f"Successfully listed {len(result)} user configurations") + return result + + except Exception as e: + logger.error(f"Error listing user configurations: {e}") + return {} + + def close(self) -> None: + """Close Redis connection. + + This method should be called when the RedisPersistentUserManager is no longer needed + to ensure proper cleanup of Redis connections. 
+ """ + try: + if hasattr(self, "_redis_client") and self._redis_client: + self._redis_client.close() + logger.info("Redis connection closed") + except Exception as e: + logger.error(f"Error closing Redis connection: {e}") diff --git a/src/memos/memories/factory.py b/src/memos/memories/factory.py index 9fdc67c53..bcf7fdd9b 100644 --- a/src/memos/memories/factory.py +++ b/src/memos/memories/factory.py @@ -10,6 +10,7 @@ from memos.memories.textual.base import BaseTextMemory from memos.memories.textual.general import GeneralTextMemory from memos.memories.textual.naive import NaiveTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -20,6 +21,7 @@ class MemoryFactory(BaseMemory): "naive_text": NaiveTextMemory, "general_text": GeneralTextMemory, "tree_text": TreeTextMemory, + "simple_tree_text": SimpleTreeTextMemory, "kv_cache": KVCacheMemory, "vllm_kv_cache": VLLMKVCacheMemory, "lora": LoRAMemory, diff --git a/src/memos/memories/textual/base.py b/src/memos/memories/textual/base.py index 8171fadce..82dad4486 100644 --- a/src/memos/memories/textual/base.py +++ b/src/memos/memories/textual/base.py @@ -24,7 +24,7 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: """ @abstractmethod - def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: + def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: """Add memories. 
class SimpleTreeTextMemory(TreeTextMemory):
    """Tree-structured textual memory assembled from pre-built components.

    The LLM, embedder, graph store, reranker, memory manager and internet
    retriever are injected by the caller. The add, search and working-memory
    methods accept an optional ``user_name`` that is forwarded to the memory
    manager, graph store or searcher.
    """

    def __init__(
        self,
        llm: BaseLLM,
        embedder: BaseEmbedder,
        mem_reader: BaseMemReader,
        graph_db: BaseGraphDB,
        reranker: BaseReranker,
        memory_manager: MemoryManager,
        config: TreeTextMemoryConfig,
        internet_retriever: Any | None = None,
        is_reorganize: bool = False,
    ):
        """Wire the injected components onto the instance.

        Args:
            llm: Used as both the extractor LLM and the dispatcher LLM.
            embedder: Embedder used to embed queries.
            mem_reader: Accepted for signature compatibility; not referenced by this class.
            graph_db: Graph store holding the memory nodes.
            reranker: Reranker handed to every ``Searcher``.
            memory_manager: Manager that performs the actual writes.
            config: Memory configuration; only ``config.internet_retriever`` is read here.
            internet_retriever: Pre-built retriever instance. It is only used when
                ``config.internet_retriever`` is set.
            is_reorganize: Accepted for signature compatibility; not referenced by this class.
        """
        # NOTE(review): TreeTextMemory.__init__ is not called — presumably because the
        # parent builds its own components from ``config``; confirm before adding super().
        init_start = time.time()
        self.config: TreeTextMemoryConfig = config
        self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm
        self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm
        self.embedder: OllamaEmbedder = embedder
        self.graph_store: Neo4jGraphDB = graph_db
        self.reranker = reranker
        self.memory_manager: MemoryManager = memory_manager

        # The retriever is only enabled when the config asks for one AND an
        # instance was actually supplied; otherwise internet search stays off.
        self.internet_retriever = None
        if config.internet_retriever is None:
            logger.info("No internet retriever configured")
        elif internet_retriever is None:
            logger.warning(
                "Internet retriever is configured but no instance was supplied; "
                "internet search is disabled"
            )
        else:
            self.internet_retriever = internet_retriever
            logger.info(
                f"Internet retriever initialized with backend: {config.internet_retriever.backend}"
            )
        # All components arrive pre-built, so a single total is the only meaningful timing.
        logger.info(f"time init: SimpleTreeTextMemory time is: {time.time() - init_start}")

    def add(
        self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None
    ) -> list[str]:
        """Add memories by delegating to the memory manager.

        Args:
            memories: List of TextualMemoryItem objects or dictionaries to add.
            user_name: Optional name forwarded unchanged to ``MemoryManager.add``.

        Returns:
            The value returned by ``MemoryManager.add``.

        Later:
            memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories]
            metadata = extract_metadata(memory_items, self.extractor_llm)
            plan = plan_memory_operations(memory_items, metadata, self.graph_store)
            execute_plan(memory_items, metadata, plan, self.graph_store)
        """
        return self.memory_manager.add(memories, user_name=user_name)

    def replace_working_memory(
        self, memories: list[TextualMemoryItem], user_name: str | None = None
    ) -> None:
        """Replace the working memory by delegating to the memory manager."""
        self.memory_manager.replace_working_memory(memories, user_name=user_name)

    def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]:
        """Return all WorkingMemory items, most recently updated first.

        Items without an ``updated_at`` value sort last (``datetime.min`` is
        used as their key).
        """
        working_memories = self.graph_store.get_all_memory_items(
            scope="WorkingMemory", user_name=user_name
        )
        items = [TextualMemoryItem.from_dict(record) for record in working_memories]
        return sorted(items, key=lambda x: x.metadata.updated_at or datetime.min, reverse=True)

    def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]:
        """Get the current size of each memory type (delegates to the MemoryManager)."""
        return self.memory_manager.get_current_memory_size(user_name=user_name)

    def search(
        self,
        query: str,
        top_k: int,
        info=None,
        mode: str = "fast",
        memory_type: str = "All",
        manual_close_internet: bool = False,
        moscube: bool = False,
        search_filter: dict | None = None,
        user_name: str | None = None,
    ) -> list[TextualMemoryItem]:
        """Search for memories based on a query.

        User query -> TaskGoalParser -> MemoryPathResolver ->
        GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output

        Args:
            query (str): The query to search for.
            top_k (int): The number of top results to return.
            info (dict): Leave a record of memory consumption.
            mode (str, optional): The mode of the search.
                - 'fast': Uses a faster search process, sacrificing some precision for speed.
                - 'fine': Uses a more detailed search process, invoking large models for
                  higher precision, but slower performance.
            memory_type (str): Type restriction for search.
                ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
            manual_close_internet (bool): If True, this search runs without the internet
                retriever even when one is configured.
            moscube (bool): Whether to use moscube to answer questions.
            search_filter (dict, optional): Optional metadata filters for search results.
                - Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
                - Values are exact-match conditions.
                Example: {"user_id": "123", "session_id": "abc"}
                If None, no additional filtering is applied.
            user_name (str, optional): Forwarded unchanged to ``Searcher.search``.

        Returns:
            list[TextualMemoryItem]: List of matching memories.
        """
        retriever = self.internet_retriever
        if retriever is not None and manual_close_internet:
            logger.warning(
                "Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
            )
            retriever = None

        searcher = Searcher(
            self.dispatcher_llm,
            self.graph_store,
            self.embedder,
            self.reranker,
            internet_retriever=retriever,
            moscube=moscube,
        )
        return searcher.search(
            query, top_k, info, mode, memory_type, search_filter, user_name=user_name
        )

    def get_relevant_subgraph(
        self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
    ) -> dict[str, Any]:
        """Find and merge the local sub-graphs of the top-k nodes most similar to the query.

        Process:
            1. Embed the user query into a vector representation.
            2. Use vector similarity search to find the top-k similar nodes.
            3. For each similar node, fetch its subgraph up to ``depth`` hops,
               restricted by ``center_status``; nodes for which the graph store
               returns no core node are skipped.
            4. Merge all retrieved subgraphs into a single unified subgraph.

        Args:
            query (str): The user input or concept to find relevant memories for.
            top_k (int, optional): How many top similar nodes to retrieve. Default is 5.
            depth (int, optional): The neighborhood depth (number of hops). Default is 2.
            center_status (str, optional): Status condition the center node must satisfy.
                Default is 'activated'.

        Returns:
            dict[str, Any]: A subgraph dict with:
                - 'core_id': ID of the first similar node (in the order returned by the
                  graph store) that was not skipped, or None if none qualified.
                - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph.
                - 'edges': List of unique edges as dicts with 'source', 'target', 'type'.
        """
        # Step 1: Embed query
        query_embedding = self.embedder.embed([query])[0]

        # Step 2: Get the top-k similar nodes
        similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k)
        if not similar_nodes:
            logger.info("No similar nodes found for query embedding.")
            return {"core_id": None, "nodes": [], "edges": []}

        # Step 3: Fetch and merge neighborhoods
        all_nodes: dict[Any, Any] = {}
        all_edges: set[tuple] = set()
        top_core_id = None

        for node in similar_nodes:
            core_id = node["id"]
            subgraph = self.graph_store.get_subgraph(
                center_id=core_id, depth=depth, center_status=center_status
            )

            core_node = subgraph["core_node"]
            if not core_node:
                logger.info(f"Skipping node {core_id} (inactive or not found).")
                continue

            if top_core_id is None:
                top_core_id = core_id

            all_nodes[core_node["id"]] = core_node
            for neighbor in subgraph["neighbors"]:
                all_nodes[neighbor["id"]] = neighbor
            for edge in subgraph["edges"]:
                all_edges.add((edge["source"], edge["target"], edge["type"]))

        # Every candidate may have been skipped above; return the same empty
        # result as the "no similar nodes" case instead of indexing an empty list.
        if top_core_id is None:
            logger.info("No similar node satisfied the center status filter.")
            return {"core_id": None, "nodes": [], "edges": []}

        return {
            "core_id": top_core_id,
            "nodes": list(all_nodes.values()),
            "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges],
        }

    def extract(self, messages: MessageList) -> list[TextualMemoryItem]:
        """Not implemented for this memory type."""
        raise NotImplementedError

    def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
        """Not implemented for this memory type."""
        raise NotImplementedError

    def get(self, memory_id: str) -> TextualMemoryItem:
        """Get a memory by its ID.

        Raises:
            ValueError: If the graph store has no node with ``memory_id``.
        """
        result = self.graph_store.get_node(memory_id)
        if result is None:
            raise ValueError(f"Memory with ID {memory_id} not found")
        metadata_dict = result.get("metadata", {})
        return TextualMemoryItem(
            id=result["id"],
            memory=result["memory"],
            metadata=TreeNodeTextualMemoryMetadata(**metadata_dict),
        )

    def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
        """Not implemented for this memory type."""
        raise NotImplementedError

    def get_all(self) -> dict:
        """Get all memories.

        Returns:
            dict: The result of ``graph_store.export_graph()``.
        """
        return self.graph_store.export_graph()

    def delete(self, memory_ids: list[str]) -> None:
        """Not implemented for this memory type."""
        raise NotImplementedError

    def delete_all(self) -> None:
        """Delete all memories and their relationships from the graph store.

        Any exception raised by the graph store is logged and re-raised.
        """
        try:
            self.graph_store.clear()
            logger.info("All memories and edges have been deleted from the graph.")
        except Exception as e:
            logger.error(f"An error occurred while deleting all memories: {e}")
            raise
def add(self, memories: list[TextualMemoryItem], user_name: str | None = None) -> list[str]:
    """
    Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory).

    Each memory is handed to ``_process_memory`` on a worker thread. Afterwards
    every memory type is trimmed to its own configured capacity and the cached
    per-type counts are refreshed.

    Args:
        memories: Memory items to store.
        user_name: Optional name forwarded to every graph-store call made here.

    Returns:
        The IDs returned by ``_process_memory`` for the memories that were
        processed without error. A failure for one memory is logged and does
        not abort the others.

    Raises:
        TimeoutError: Propagated from ``as_completed`` if the workers do not
            all finish within 60 seconds; it is not caught here.
    """
    added_ids: list[str] = []

    with ContextThreadPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(self._process_memory, m, user_name): m for m in memories}
        for future in as_completed(futures, timeout=60):
            try:
                ids = future.result()
                added_ids.extend(ids)
            except Exception as e:
                logger.exception("Memory processing error: ", exc_info=e)

    # Trim each memory type against its OWN limit. Passing a fixed
    # "WorkingMemory" here would leave the other two types untrimmed and
    # apply their limits to working memory instead.
    for mem_type in ("WorkingMemory", "LongTermMemory", "UserMemory"):
        try:
            self.graph_store.remove_oldest_memory(
                memory_type=mem_type,
                keep_latest=self.memory_size[mem_type],
                user_name=user_name,
            )
        except Exception:
            # Best effort: a failed trim must not lose the IDs already added.
            logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}")

    self._refresh_memory_size(user_name=user_name)
    return added_ids
executor.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name=user_name + ) for memory in working_memory_top_k ] for future in as_completed(futures, timeout=60): @@ -107,47 +100,51 @@ def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: logger.exception("Memory processing error: ", exc_info=e) self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] + memory_type="WorkingMemory", + keep_latest=self.memory_size["WorkingMemory"], + user_name=user_name, ) - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) - def get_current_memory_size(self) -> dict[str, int]: + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: """ Return the cached memory type counts. """ - self._refresh_memory_size() + self._refresh_memory_size(user_name=user_name) return self.current_memory_size - def _refresh_memory_size(self) -> None: + def _refresh_memory_size(self, user_name: str | None = None) -> None: """ Query the latest counts from the graph store and update internal state. """ - results = self.graph_store.get_grouped_counts(group_fields=["memory_type"]) + results = self.graph_store.get_grouped_counts( + group_fields=["memory_type"], user_name=user_name + ) self.current_memory_size = {record["memory_type"]: record["count"] for record in results} logger.info(f"[MemoryManager] Refreshed memory sizes: {self.current_memory_size}") - def _process_memory(self, memory: TextualMemoryItem): + def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). This method runs asynchronously to process each memory item. 
""" ids = [] - # Add to WorkingMemory - working_id = self._add_memory_to_db(memory, "WorkingMemory") - ids.append(working_id) + # Add to WorkingMemory do not return working_id + self._add_memory_to_db(memory, "WorkingMemory", user_name) # Add to LongTermMemory and UserMemory if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]: added_id = self._add_to_graph_memory( - memory=memory, - memory_type=memory.metadata.memory_type, + memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name ) ids.append(added_id) return ids - def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: + def _add_memory_to_db( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. """ @@ -158,10 +155,12 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str: working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) # Insert node into graph - self.graph_store.add_node(working_memory.id, working_memory.memory, metadata) + self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) return working_memory.id - def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): + def _add_to_graph_memory( + self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). 
@@ -175,7 +174,10 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str): node_id = str(uuid.uuid4()) # Step 2: Add new node to graph self.graph_store.add_node( - node_id, memory.memory, memory.metadata.model_dump(exclude_none=True) + node_id, + memory.memory, + memory.metadata.model_dump(exclude_none=True), + user_name=user_name, ) self.reorganizer.add_message( QueueMessage( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 84cc8ecb3..d4cfcf501 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -30,6 +30,7 @@ def retrieve( memory_scope: str, query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -53,13 +54,13 @@ def retrieve( if memory_scope == "WorkingMemory": # For working memory, retrieve all entries (no filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False + scope="WorkingMemory", include_embedding=False, user_name=user_name ) return [TextualMemoryItem.from_dict(record) for record in working_memories] with ContextThreadPoolExecutor(max_workers=2) as executor: # Structured graph-based retrieval - future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope) + future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name) # Vector similarity search future_vector = executor.submit( self._vector_recall, @@ -67,6 +68,7 @@ def retrieve( memory_scope, top_k, search_filter=search_filter, + user_name=user_name, ) graph_results = future_graph.result() @@ -92,6 +94,7 @@ def retrieve_from_cube( memory_scope: str, query_embedding: list[list[float]] | None = None, cube_name: str = "memos_cube01", + user_name: str | None = 
None, ) -> list[TextualMemoryItem]: """ Perform hybrid memory retrieval: @@ -112,7 +115,7 @@ def retrieve_from_cube( raise ValueError(f"Unsupported memory scope: {memory_scope}") graph_results = self._vector_recall( - query_embedding, memory_scope, top_k, cube_name=cube_name + query_embedding, memory_scope, top_k, cube_name=cube_name, user_name=user_name ) for result_i in graph_results: @@ -132,7 +135,7 @@ def retrieve_from_cube( return list(combined.values()) def _graph_recall( - self, parsed_goal: ParsedTaskGoal, memory_scope: str + self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None ) -> list[TextualMemoryItem]: """ Perform structured node-based retrieval from Neo4j. @@ -148,7 +151,7 @@ def _graph_recall( {"field": "key", "op": "in", "value": parsed_goal.keys}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - key_ids = self.graph_store.get_by_metadata(key_filters) + key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) candidate_ids.update(key_ids) # 2) tag-based OR branch @@ -157,7 +160,7 @@ def _graph_recall( {"field": "tags", "op": "contains", "value": parsed_goal.tags}, {"field": "memory_type", "op": "=", "value": memory_scope}, ] - tag_ids = self.graph_store.get_by_metadata(tag_filters) + tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) candidate_ids.update(tag_ids) # No matches → return empty @@ -165,7 +168,9 @@ def _graph_recall( return [] # Load nodes and post-filter - node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False) + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_name + ) final_nodes = [] for node in node_dicts: @@ -194,6 +199,7 @@ def _vector_recall( max_num: int = 3, cube_name: str | None = None, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Perform vector-based similarity retrieval using query embedding. 
@@ -210,6 +216,7 @@ def search_single(vec, filt=None): scope=memory_scope, cube_name=cube_name, search_filter=filt, + user_name=user_name, ) or [] ) @@ -255,7 +262,7 @@ def search_path_b(): unique_ids = {r["id"] for r in all_hits if r.get("id")} node_dicts = ( self.graph_store.get_nodes( - list(unique_ids), include_embedding=False, cube_name=cube_name + list(unique_ids), include_embedding=False, cube_name=cube_name, user_name=user_name ) or [] ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index df154f23a..05db56f53 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -12,7 +12,6 @@ from memos.reranker.base import BaseReranker from memos.utils import timed -from .internet_retriever_factory import InternetRetrieverFactory from .reasoner import MemoryReasoner from .recall import GraphMemoryRetriever from .task_goal_parser import TaskGoalParser @@ -28,7 +27,7 @@ def __init__( graph_store: Neo4jGraphDB, embedder: OllamaEmbedder, reranker: BaseReranker, - internet_retriever: InternetRetrieverFactory | None = None, + internet_retriever: None = None, moscube: bool = False, ): self.graph_store = graph_store @@ -54,6 +53,7 @@ def search( mode="fast", memory_type="All", search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. 
@@ -85,14 +85,22 @@ def search( logger.debug(f"[SEARCH] Received info dict: {info}") parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter + query, info, mode, search_filter=search_filter, user_name=user_name ) results = self._retrieve_paths( - query, parsed_goal, query_embedding, info, top_k, mode, memory_type, search_filter + query, + parsed_goal, + query_embedding, + info, + top_k, + mode, + memory_type, + search_filter, + user_name, ) deduped = self._deduplicate_results(results) final_results = self._sort_and_trim(deduped, top_k) - self._update_usage_history(final_results, info) + self._update_usage_history(final_results, info, user_name) logger.info(f"[SEARCH] Done. Total {len(final_results)} results.") res_results = "" @@ -104,7 +112,15 @@ def search( return final_results @timed - def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = None): + def _parse_task( + self, + query, + info, + mode, + top_k=5, + search_filter: dict | None = None, + user_name: str | None = None, + ): """Parse user query, do embedding search and create context""" context = [] query_embedding = None @@ -118,7 +134,7 @@ def _parse_task(self, query, info, mode, top_k=5, search_filter: dict | None = N related_nodes = [ self.graph_store.get_node(n["id"]) for n in self.graph_store.search_by_embedding( - query_embedding, top_k=top_k, search_filter=search_filter + query_embedding, top_k=top_k, search_filter=search_filter, user_name=user_name ) ] memories = [] @@ -168,6 +184,7 @@ def _retrieve_paths( mode, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Run A/B/C retrieval paths in parallel""" tasks = [] @@ -181,6 +198,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -192,6 +210,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + user_name, ) ) tasks.append( @@ -204,6 +223,7 @@ def _retrieve_paths( info, mode, 
memory_type, + user_name, ) ) if self.moscube: @@ -235,6 +255,7 @@ def _retrieve_from_working_memory( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from WorkingMemory""" if memory_type not in ["All", "WorkingMemory"]: @@ -246,6 +267,7 @@ def _retrieve_from_working_memory( top_k=top_k, memory_scope="WorkingMemory", search_filter=search_filter, + user_name=user_name, ) return self.reranker.rerank( query=query, @@ -266,6 +288,7 @@ def _retrieve_from_long_term_and_user( top_k, memory_type, search_filter: dict | None = None, + user_name: str | None = None, ): """Retrieve and rerank from LongTermMemory and UserMemory""" results = [] @@ -282,6 +305,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, + user_name=user_name, ) ) if memory_type in ["All", "UserMemory"]: @@ -294,6 +318,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, + user_name=user_name, ) ) @@ -320,6 +345,7 @@ def _retrieve_from_memcubes( top_k=top_k * 2, memory_scope="LongTermMemory", cube_name=cube_name, + user_name=cube_name, ) return self.reranker.rerank( query=query, @@ -332,7 +358,15 @@ def _retrieve_from_memcubes( # --- Path C @timed def _retrieve_from_internet( - self, query, parsed_goal, query_embedding, top_k, info, mode, memory_type + self, + query, + parsed_goal, + query_embedding, + top_k, + info, + mode, + memory_type, + user_id: str | None = None, ): """Retrieve and rerank from Internet source""" if not self.internet_retriever or mode == "fast": @@ -380,7 +414,7 @@ def _sort_and_trim(self, results, top_k): return final_items @timed - def _update_usage_history(self, items, info): + def _update_usage_history(self, items, info, user_name: str | None = None): """Update usage history in graph DB""" now_time = datetime.now().isoformat() info_copy = dict(info or {}) @@ -402,11 +436,15 @@ def 
# ─── API Types ────────────────────────────────────────────────────────────────────
# Closed set of permission names accepted by the API.
Permission: TypeAlias = Literal["read", "write", "delete", "execute"]


# Permission structure
class PermissionDict(TypedDict, total=False):
    """Typed dictionary pairing a list of permissions with a memory-cube ID.

    Declared with ``total=False``, so both keys are optional.
    """

    permissions: list[Permission]
    mem_cube_id: str


class UserContext(BaseModel):
    """Model to represent user context.

    Every field is optional and defaults to ``None``.
    """

    user_id: str | None = None
    mem_cube_id: str | None = None
    session_id: str | None = None
    # presumably the permission grants that apply to this user — TODO confirm with the API layer
    operation: list[PermissionDict] | None = None
assert len(item.metadata.usage) > 0 mock_searcher.graph_store.update_node.assert_any_call( - item.id, {"usage": item.metadata.usage} + item.id, {"usage": item.metadata.usage}, user_name=None )