diff --git a/.gitignore b/.gitignore index ab7848f74..c972d59de 100644 --- a/.gitignore +++ b/.gitignore @@ -234,6 +234,6 @@ apps/openwork-memos-integration/apps/desktop/public/assets/usecases/ # Outputs and Evaluation Results outputs -evaluation/data/temporal_locomo +evaluation/data/ test_add_pipeline.py test_file_pipeline.py diff --git a/Makefile b/Makefile index 788504a73..eb22e241d 100644 --- a/Makefile +++ b/Makefile @@ -36,7 +36,7 @@ pre_commit: poetry run pre-commit run -a serve: - poetry run uvicorn memos.api.start_api:app + poetry run uvicorn memos.api.server_api:app openapi: poetry run memos export_openapi --output docs/openapi.json diff --git a/README.md b/README.md index a7b05d683..1c7d5bd93 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ - [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) โ€” intelligent memory retrieval instead of loading full chat history - [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) โ€” multi-instance agents share memory via same user_id, automatic context handoff -Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) +Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin) ### ๐Ÿง  Local Plugin โ€” 100% On-Device Memory @@ -84,7 +84,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - **Hybrid search + task & skill evolution** โ€” FTS5 + vector search, auto task summarization, reusable skills that self-upgrade - **Multi-agent collaboration + Memory Viewer** โ€” memory isolation, skill sharing, full web dashboard with 7 management pages - ๐ŸŒ [Homepage](https://memos-claw.openmem.net) ยท + ๐ŸŒ [Homepage](https://memos-claw.openmem.net) ยท ๐Ÿ“– [Documentation](https://memos-claw.openmem.net/docs/index.html) ยท ๐Ÿ“ฆ [NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin) ## ๐Ÿ“Œ MemOS: Memory Operating System for AI Agents @@ -104,10 +104,10 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ### News -- **2026-03-08** ยท ๐Ÿฆž **MemOS OpenClaw Plugin โ€” Cloud & Local** +- **2026-03-08** ยท ๐Ÿฆž **MemOS OpenClaw Plugin โ€” Cloud & Local** Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard. -- **2025-12-24** ยท ๐ŸŽ‰ **MemOS v2.0: Stardust (ๆ˜Ÿๅฐ˜) Release** +- **2025-12-24** ยท ๐ŸŽ‰ **MemOS v2.0: Stardust (ๆ˜Ÿๅฐ˜) Release** Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
โœจ New Features @@ -155,7 +155,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** ยท ๐ŸŽ‰ **MemOS v1.0.0 (MemCube) Release** - First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch. + First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
โœจ New Features @@ -176,7 +176,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem **Plaintext Memory** - Integrated internet search with Bocha. - - Added support for Nebula database. + - Expanded graph database support. - Added contextual understanding for the tree-structured plaintext memory search interface.
@@ -188,7 +188,7 @@ Full tutorial โ†’ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - Fixed the concat_cache method. **Plaintext Memory** - - Fixed Nebula search-related issues. + - Fixed graph search-related issues. diff --git a/docker/.env.example b/docker/.env.example index 3674cd69b..08f9cf1db 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -127,7 +127,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false ## Graph / vector stores # Neo4j database selection mode -NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb +NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres # Neo4j database url NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j* # Neo4j database user diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 6dbe202c2..d14b6d687 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -47,7 +47,7 @@ def build_minimal_components(): # Build component configurations using APIConfig methods (like config_builders.py) - # Graph DB configuration - using APIConfig.get_nebular_config() + # Graph DB configuration - using APIConfig graph DB helpers graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() graph_db_backend_map = { "polardb": APIConfig.get_polardb_config(), diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 87f1efd8e..2e8bae57c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -741,21 +741,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } - @staticmethod - def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: - """Get Nebular configuration.""" - return { - "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')), - "user": os.getenv("NEBULAR_USER", "root"), - "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"), - "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"), - "user_name": f"memos{user_id.replace('-', '')}", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), - } - - @staticmethod def get_milvus_config(): return { "collection_name": [ @@ -1103,7 +1088,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) neo4j_config = APIConfig.get_neo4j_config(user_id) - nebular_config = APIConfig.get_nebular_config(user_id) polardb_config = APIConfig.get_polardb_config(user_id) internet_config = ( APIConfig.get_internet_config() @@ -1114,7 +1098,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1144,9 +1127,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true", "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1169,7 +1152,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") default_mem_cube = GeneralMemCube(default_cube_config) return default_config, default_mem_cube @@ -1188,13 +1171,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": openai_config = APIConfig.get_openai_config() neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") neo4j_config = APIConfig.get_neo4j_config(user_id="default") - nebular_config = APIConfig.get_nebular_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") postgres_config = APIConfig.get_postgres_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1227,9 +1208,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": == "true", "internet_retriever": internet_config, "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1253,4 +1234,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index e071eacb3..eb982fd06 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -36,7 +36,6 @@ def __init__( vector_db: Any | None = None, internet_retriever: Any | None = None, memory_manager: Any | None = None, - mos_server: Any | None = None, feedback_server: Any | None = None, **kwargs, ): @@ -54,7 +53,6 @@ def __init__( vector_db: Vector database instance internet_retriever: Internet retriever instance memory_manager: Memory manager instance - mos_server: MOS server instance **kwargs: Additional dependencies """ self.llm = llm @@ -68,7 +66,6 @@ def __init__( self.vector_db = vector_db self.internet_retriever = internet_retriever self.memory_manager = memory_manager - self.mos_server = mos_server self.feedback_server = feedback_server # Store any additional dependencies @@ -158,11 +155,6 @@ def vector_db(self): """Get vector database instance.""" return self.deps.vector_db - @property - def mos_server(self): - """Get MOS server instance.""" - return self.deps.mos_server - @property def deepsearch_agent(self): """Get deepsearch agent instance.""" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 7894ff7dc..a01fffef8 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -28,7 +28,6 @@ from memos.log import get_logger from memos.mem_cube.navie import NaiveMemCube from memos.mem_feedback.simple_feedback import SimpleMemFeedback -from memos.mem_os.product_server import MOSServer from memos.mem_reader.factory import MemReaderFactory from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory @@ -211,15 +210,6 @@ def init_server() -> dict[str, Any]: logger.debug("Text memory initialized") - # Initialize MOS Server - mos_server = MOSServer( - mem_reader=mem_reader, - llm=llm, - online_bot=False, - ) - - logger.debug("MOS server initialized") - # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( text_mem=text_mem, @@ -304,7 +294,6 @@ def init_server() -> dict[str, Any]: "internet_retriever": internet_retriever, "memory_manager": memory_manager, "default_cube_config": default_cube_config, - "mos_server": mos_server, "mem_scheduler": mem_scheduler, "naive_mem_cube": naive_mem_cube, "searcher": searcher, diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index 5655bf1e5..d29429fc9 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: graph_db_backend_map = { "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py deleted file mode 100644 index ec5cccae1..000000000 --- a/src/memos/api/product_api.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging - -from fastapi import FastAPI - -from memos.api.exceptions import APIExceptionHandler -from memos.api.middleware.request_context import RequestContextMiddleware -from memos.api.routers.product_router import router as product_router - - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -app = FastAPI( - title="MemOS Product REST APIs", - description="A REST API for managing multiple users with MemOS Product.", - version="1.0.1", -) - -app.add_middleware(RequestContextMiddleware, source="product_api") -# Include routers -app.include_router(product_router) - -# Exception handlers -app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) -app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) - - -if __name__ == "__main__": - import argparse - - import uvicorn - - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8001) - parser.add_argument("--workers", type=int, default=1) - args = parser.parse_args() - uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers) diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py deleted file mode 100644 index 609d61124..000000000 --- a/src/memos/api/routers/product_router.py +++ /dev/null @@ -1,477 +0,0 @@ -import json -import time -import traceback - -from fastapi import APIRouter, HTTPException -from fastapi.responses import StreamingResponse - -from memos.api.config import APIConfig -from memos.api.product_models import ( - BaseResponse, - ChatCompleteRequest, - ChatRequest, - GetMemoryPlaygroundRequest, - MemoryCreateRequest, - MemoryResponse, - SearchRequest, - SearchResponse, - SimpleResponse, - SuggestionRequest, - SuggestionResponse, - UserRegisterRequest, - UserRegisterResponse, -) -from memos.configs.mem_os import MOSConfig -from memos.log import get_logger -from memos.mem_os.product import MOSProduct -from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function - - -logger = get_logger(__name__) - -router = APIRouter(prefix="/product", tags=["Product API"]) - -# Initialize MOSProduct instance with lazy initialization -MOS_PRODUCT_INSTANCE = None - - -def get_mos_product_instance(): - """Get or create MOSProduct instance.""" - global MOS_PRODUCT_INSTANCE - if MOS_PRODUCT_INSTANCE is None: - default_config = APIConfig.get_product_default_config() - logger.info(f"*********init_default_mos_config********* {default_config}") - from memos.configs.mem_os import MOSConfig - - mos_config = MOSConfig(**default_config) - - # Get default cube config from APIConfig (may be None if disabled) - default_cube_config = APIConfig.get_default_cube_config() - logger.info(f"*********initdefault_cube_config******** {default_cube_config}") - - # Get DingDing bot functions - dingding_enabled = APIConfig.is_dingding_bot_enabled() - online_bot = get_online_bot_function() if dingding_enabled else None - error_bot = get_error_bot_function() if dingding_enabled else None - - MOS_PRODUCT_INSTANCE = MOSProduct( - default_config=mos_config, - default_cube_config=default_cube_config, - online_bot=online_bot, - error_bot=error_bot, - ) - logger.info("MOSProduct instance created successfully with inheritance architecture") - return MOS_PRODUCT_INSTANCE - - -get_mos_product_instance() - - -@router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse) -def set_config(config): - """Set MOSProduct configuration.""" - global MOS_PRODUCT_INSTANCE - MOS_PRODUCT_INSTANCE = MOSProduct(default_config=config) - return SimpleResponse(message="Configuration set successfully") - - -@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse) -def register_user(user_req: UserRegisterRequest): - """Register a new user with configuration and default cube.""" - try: - # Get configuration for the user - time_start_register = time.time() - user_config, default_mem_cube = APIConfig.create_user_config( - user_name=user_req.user_id, user_id=user_req.user_id - ) - logger.info(f"user_config: {user_config.model_dump(mode='json')}") - logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}") - logger.info( - f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}" - ) - mos_product = get_mos_product_instance() - - # Register user with default config and mem cube - result = mos_product.user_register( - user_id=user_req.user_id, - user_name=user_req.user_name, - interests=user_req.interests, - config=user_config, - default_mem_cube=default_mem_cube, - mem_cube_id=user_req.mem_cube_id, - ) - logger.info( - f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}" - ) - if result["status"] == "success": - return UserRegisterResponse( - message="User registered successfully", - data={"user_id": result["user_id"], "mem_cube_id": result["default_cube_id"]}, - ) - else: - raise HTTPException(status_code=400, detail=result["message"]) - - except Exception as err: - logger.error(f"Failed to register user: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse -) -def get_suggestion_queries(user_id: str): - """Get suggestion queries for a specific user.""" - try: - mos_product = get_mos_product_instance() - suggestions = mos_product.get_suggestion_query(user_id) - return SuggestionResponse( - message="Suggestions retrieved successfully", data={"query": suggestions} - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get suggestions: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post( - "/suggestions", - summary="Get suggestion queries with language", - response_model=SuggestionResponse, -) -def get_suggestion_queries_post(suggestion_req: SuggestionRequest): - """Get suggestion queries for a specific user with language preference.""" - try: - mos_product = get_mos_product_instance() - suggestions = mos_product.get_suggestion_query( - user_id=suggestion_req.user_id, - language=suggestion_req.language, - message=suggestion_req.message, - ) - return SuggestionResponse( - message="Suggestions retrieved successfully", data={"query": suggestions} - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get suggestions: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryPlaygroundRequest): - """Get all memories for a specific user.""" - try: - mos_product = get_mos_product_instance() - if memory_req.search_query: - result = mos_product.get_subgraph( - user_id=memory_req.user_id, - query=memory_req.search_query, - mem_cube_ids=memory_req.mem_cube_ids, - ) - return MemoryResponse(message="Memories retrieved successfully", data=result) - else: - result = mos_product.get_all( - user_id=memory_req.user_id, - memory_type=memory_req.memory_type, - mem_cube_ids=memory_req.mem_cube_ids, - ) - return MemoryResponse(message="Memories retrieved successfully", data=result) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get memories: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/add", summary="add a new memory", response_model=SimpleResponse) -def create_memory(memory_req: MemoryCreateRequest): - """Create a new memory for a specific user.""" - logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.") - # Initialize status_tracker outside try block to avoid NameError in except blocks - status_tracker = None - - try: - time_start_add = time.time() - mos_product = get_mos_product_instance() - - # Track task if task_id is provided - item_id: str | None = None - if ( - memory_req.task_id - and hasattr(mos_product, "mem_scheduler") - and mos_product.mem_scheduler - ): - from uuid import uuid4 - - from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker - - item_id = str(uuid4()) # Generate a unique item_id for this submission - - # Get Redis client from scheduler - if ( - hasattr(mos_product.mem_scheduler, "redis_client") - and mos_product.mem_scheduler.redis_client - ): - status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client) - # Submit task with "product_add" type - status_tracker.task_submitted( - task_id=item_id, # Use generated item_id for internal tracking - user_id=memory_req.user_id, - task_type="product_add", - mem_cube_id=memory_req.mem_cube_id or memory_req.user_id, - business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id - ) - status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here - - # Execute the add operation - mos_product.add( - user_id=memory_req.user_id, - memory_content=memory_req.memory_content, - messages=memory_req.messages, - doc_path=memory_req.doc_path, - mem_cube_id=memory_req.mem_cube_id, - source=memory_req.source, - user_profile=memory_req.user_profile, - session_id=memory_req.session_id, - task_id=memory_req.task_id, - ) - - # Mark task as completed - if status_tracker and item_id: - status_tracker.task_completed(item_id, memory_req.user_id) - - logger.info( - f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}" - ) - return SimpleResponse(message="Memory created successfully") - - except ValueError as err: - # Mark task as failed if tracking - if status_tracker and item_id: - status_tracker.task_failed(item_id, memory_req.user_id, str(err)) - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - # Mark task as failed if tracking - if status_tracker and item_id: - status_tracker.task_failed(item_id, memory_req.user_id, str(err)) - logger.error(f"Failed to create memory: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/search", summary="Search memories", response_model=SearchResponse) -def search_memories(search_req: SearchRequest): - """Search memories for a specific user.""" - try: - time_start_search = time.time() - mos_product = get_mos_product_instance() - result = mos_product.search( - query=search_req.query, - user_id=search_req.user_id, - install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None, - top_k=search_req.top_k, - session_id=search_req.session_id, - ) - logger.info( - f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}" - ) - return SearchResponse(message="Search completed successfully", data=result) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to search memories: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/chat", summary="Chat with MemOS") -def chat(chat_req: ChatRequest): - """Chat with MemOS for a specific user. Returns SSE stream.""" - try: - mos_product = get_mos_product_instance() - - def generate_chat_response(): - """Generate chat response as SSE stream.""" - try: - # Directly yield from the generator without async wrapper - yield from mos_product.chat_with_references( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - session_id=chat_req.session_id, - ) - - except Exception as e: - logger.error(f"Error in chat stream: {e}") - error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" - yield error_data - - return StreamingResponse( - generate_chat_response(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "text/event-stream", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "*", - "Access-Control-Allow-Methods": "*", - }, - ) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") -def chat_complete(chat_req: ChatCompleteRequest): - """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" - try: - mos_product = get_mos_product_instance() - - # Collect all responses from the generator - content, references = mos_product.chat( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt or chat_req.system_prompt, - # will deprecate base_prompt in the future - top_k=chat_req.top_k, - threshold=chat_req.threshold, - session_id=chat_req.session_id, - ) - - # Return the complete response - return { - "message": "Chat completed successfully", - "data": {"response": content, "references": references}, - } - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/users", summary="List all users", response_model=BaseResponse[list]) -def list_users(): - """List all registered users.""" - try: - mos_product = get_mos_product_instance() - users = mos_product.list_users() - return BaseResponse(message="Users retrieved successfully", data=users) - except Exception as err: - logger.error(f"Failed to list users: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/users/{user_id}", summary="Get user info", response_model=BaseResponse[dict]) -async def get_user_info(user_id: str): - """Get user information including accessible cubes.""" - try: - mos_product = get_mos_product_instance() - user_info = mos_product.get_user_info(user_id) - return BaseResponse(message="User info retrieved successfully", data=user_info) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get user info: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/configure/{user_id}", summary="Get MOSProduct configuration", response_model=SimpleResponse -) -def get_config(user_id: str): - """Get MOSProduct configuration.""" - global MOS_PRODUCT_INSTANCE - config = MOS_PRODUCT_INSTANCE.default_config - return SimpleResponse(message="Configuration retrieved successfully", data=config) - - -@router.get( - "/users/{user_id}/config", summary="Get user configuration", response_model=BaseResponse[dict] -) -def get_user_config(user_id: str): - """Get user-specific configuration.""" - try: - mos_product = get_mos_product_instance() - config = mos_product.get_user_config(user_id) - if config: - return BaseResponse( - message="User configuration retrieved successfully", - data=config.model_dump(mode="json"), - ) - else: - raise HTTPException( - status_code=404, detail=f"Configuration not found for user {user_id}" - ) - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to get user config: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.put( - "/users/{user_id}/config", summary="Update user configuration", response_model=SimpleResponse -) -def update_user_config(user_id: str, config_data: dict): - """Update user-specific configuration.""" - try: - mos_product = get_mos_product_instance() - - # Create MOSConfig from the provided data - config = MOSConfig(**config_data) - - # Update the configuration - success = mos_product.update_user_config(user_id, config) - if success: - return SimpleResponse(message="User configuration updated successfully") - else: - raise HTTPException(status_code=500, detail="Failed to update user configuration") - - except ValueError as err: - raise HTTPException(status_code=400, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to update user config: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get( - "/instances/status", summary="Get user configuration status", response_model=BaseResponse[dict] -) -def get_instance_status(): - """Get information about active user configurations in memory.""" - try: - mos_product = get_mos_product_instance() - status_info = mos_product.get_user_instance_info() - return BaseResponse( - message="User configuration status retrieved successfully", data=status_info - ) - except Exception as err: - logger.error(f"Failed to get user configuration status: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -@router.get("/instances/count", summary="Get active user count", response_model=BaseResponse[int]) -def get_active_user_count(): - """Get the number of active user configurations in memory.""" - try: - mos_product = get_mos_product_instance() - count = mos_product.get_active_user_count() - return BaseResponse(message="Active user count retrieved successfully", data=count) - except Exception as err: - logger.error(f"Failed to get active user count: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py deleted file mode 100644 index 24a36f017..000000000 --- a/src/memos/api/start_api.py +++ /dev/null @@ -1,433 +0,0 @@ -import logging -import os - -from typing import Any, Generic, TypeVar - -from dotenv import load_dotenv -from fastapi import FastAPI -from fastapi.requests import Request -from fastapi.responses import JSONResponse, RedirectResponse -from pydantic import BaseModel, Field - -from memos.api.middleware.request_context import RequestContextMiddleware -from memos.configs.mem_os import MOSConfig -from memos.mem_os.main import MOS -from memos.mem_user.user_manager import UserManager, UserRole - - -# Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -# Load environment variables -load_dotenv(override=True) - -T = TypeVar("T") - -# Default configuration -DEFAULT_CONFIG = { - "user_id": os.getenv("MOS_USER_ID", "default_user"), - "session_id": os.getenv("MOS_SESSION_ID", "default_session"), - "enable_textual_memory": True, - "enable_activation_memory": False, - "top_k": int(os.getenv("MOS_TOP_K", "5")), - "chat_model": { - "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"), - "config": { - "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-3.5-turbo"), - "api_key": os.getenv("OPENAI_API_KEY", "apikey"), - "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.7")), - "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), - }, - }, -} - -# Initialize MOS instance with lazy initialization -MOS_INSTANCE = None - - -def get_mos_instance(): - """Get or create MOS instance with default user creation.""" - global MOS_INSTANCE - if MOS_INSTANCE is None: - # Create a temporary MOS instance to access user manager - temp_config = MOSConfig(**DEFAULT_CONFIG) - temp_mos = MOS.__new__(MOS) - temp_mos.config = temp_config - temp_mos.user_id = temp_config.user_id - temp_mos.session_id = temp_config.session_id - temp_mos.mem_cubes = {} - temp_mos.chat_llm = None # Will be initialized later - temp_mos.user_manager = UserManager() - - # Create default user if it doesn't exist - if not temp_mos.user_manager.validate_user(temp_config.user_id): - temp_mos.user_manager.create_user( - user_name=temp_config.user_id, role=UserRole.USER, user_id=temp_config.user_id - ) - logger.info(f"Created default user: {temp_config.user_id}") - - # Now create the actual MOS instance - MOS_INSTANCE = MOS(config=temp_config) - - return MOS_INSTANCE - - -app = FastAPI( - title="MemOS REST APIs", - description="A REST API for managing and searching memories using MemOS.", - version="1.0.0", -) - -app.add_middleware(RequestContextMiddleware) - - -class BaseRequest(BaseModel): - """Base model for all requests.""" - - user_id: str | None = Field( - None, description="User ID for the request", json_schema_extra={"example": "user123"} - ) - - -class BaseResponse(BaseModel, Generic[T]): - """Base model for all responses.""" - - code: int = Field(200, description="Response status code", json_schema_extra={"example": 200}) - message: str = Field( - ..., description="Response message", json_schema_extra={"example": "Operation successful"} - ) - data: T | None = Field(None, description="Response data") - - -class Message(BaseModel): - role: str = Field( - ..., - description="Role of the message (user or assistant).", - json_schema_extra={"example": "user"}, - ) - content: str = Field( - ..., - description="Message content.", - json_schema_extra={"example": "Hello, how can I help you?"}, - ) - - -class MemoryCreate(BaseRequest): - messages: list[Message] | None = Field( - None, - description="List of messages to store.", - json_schema_extra={"example": [{"role": "user", "content": "Hello"}]}, - ) - mem_cube_id: str | None = Field( - None, description="ID of the memory cube", json_schema_extra={"example": "cube123"} - ) - memory_content: str | None = Field( - None, - description="Content to store as memory", - json_schema_extra={"example": "This is a memory content"}, - ) - doc_path: str | None = Field( - None, - description="Path to document to store", - json_schema_extra={"example": "/path/to/document.txt"}, - ) - - -class SearchRequest(BaseRequest): - query: str = Field( - ..., - description="Search query.", - json_schema_extra={"example": "How to implement a feature?"}, - ) - install_cube_ids: list[str] | None = Field( - None, - description="List of cube IDs to search in", - json_schema_extra={"example": ["cube123", "cube456"]}, - ) - - -class MemCubeRegister(BaseRequest): - mem_cube_name_or_path: str = Field( - ..., - description="Name or path of the MemCube to register.", - json_schema_extra={"example": "/path/to/cube"}, - ) - mem_cube_id: str | None = Field( - None, description="ID for the MemCube", json_schema_extra={"example": "cube123"} - ) - - -class ChatRequest(BaseRequest): - query: str = Field( - ..., - description="Chat query message.", - json_schema_extra={"example": "What is the latest update?"}, - ) - - -class UserCreate(BaseRequest): - user_name: str | None = Field( - None, description="Name of the user", json_schema_extra={"example": "john_doe"} - ) - role: str = Field("user", description="Role of the user", json_schema_extra={"example": "user"}) - user_id: str = Field(..., description="User ID", json_schema_extra={"example": "user123"}) - - -class CubeShare(BaseRequest): - target_user_id: str = Field( - ..., description="Target user ID to share with", json_schema_extra={"example": "user456"} - ) - - -class SimpleResponse(BaseResponse[None]): - """Simple response model for operations without data return.""" - - -class ConfigResponse(BaseResponse[None]): - """Response model for configuration endpoint.""" - - -class MemoryResponse(BaseResponse[dict]): - """Response model for memory operations.""" - - -class SearchResponse(BaseResponse[dict]): - """Response model for search operations.""" - - -class ChatResponse(BaseResponse[str]): - """Response model for chat operations.""" - - -class UserResponse(BaseResponse[dict]): - """Response model for user operations.""" - - -class UserListResponse(BaseResponse[list]): - """Response model for user list operations.""" - - -@app.post("/configure", summary="Configure MemOS", response_model=ConfigResponse) -async def set_config(config: MOSConfig): - """Set MemOS configuration.""" - global MOS_INSTANCE - - # Create a temporary user manager to check/create default user - temp_user_manager = UserManager() - - # Create default user if it doesn't exist - if not temp_user_manager.validate_user(config.user_id): - temp_user_manager.create_user( - user_name=config.user_id, role=UserRole.USER, user_id=config.user_id - ) - logger.info(f"Created default user: {config.user_id}") - - # Now create the MOS instance - MOS_INSTANCE = MOS(config=config) - return ConfigResponse(message="Configuration set successfully") - - -@app.post("/users", summary="Create a new user", response_model=UserResponse) -async def create_user(user_create: UserCreate): - """Create a new user.""" - mos_instance = get_mos_instance() - role = UserRole(user_create.role) - user_id = mos_instance.create_user( - user_id=user_create.user_id, role=role, user_name=user_create.user_name - ) - return UserResponse(message="User created successfully", data={"user_id": user_id}) - - -@app.get("/users", summary="List all users", response_model=UserListResponse) -async def list_users(): - """List all active users.""" - mos_instance = get_mos_instance() - users = mos_instance.list_users() - return UserListResponse(message="Users retrieved successfully", data=users) - - -@app.get("/users/me", summary="Get current user info", response_model=UserResponse) -async def get_user_info(): - """Get current user information including accessible cubes.""" - mos_instance = get_mos_instance() - user_info = mos_instance.get_user_info() - return UserResponse(message="User info retrieved successfully", data=user_info) - - -@app.post("/mem_cubes", summary="Register a MemCube", response_model=SimpleResponse) -async def register_mem_cube(mem_cube: MemCubeRegister): - """Register a new MemCube.""" - mos_instance = get_mos_instance() - mos_instance.register_mem_cube( - mem_cube_name_or_path=mem_cube.mem_cube_name_or_path, - mem_cube_id=mem_cube.mem_cube_id, - user_id=mem_cube.user_id, - ) - return SimpleResponse(message="MemCube registered successfully") - - -@app.delete( - "/mem_cubes/{mem_cube_id}", summary="Unregister a MemCube", response_model=SimpleResponse -) -async def unregister_mem_cube(mem_cube_id: str, user_id: str | None = None): - """Unregister a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.unregister_mem_cube(mem_cube_id=mem_cube_id, user_id=user_id) - return SimpleResponse(message="MemCube unregistered successfully") - - -@app.post( - "/mem_cubes/{cube_id}/share", - summary="Share a cube with another user", - response_model=SimpleResponse, -) -async def share_cube(cube_id: str, share_request: CubeShare): - """Share a cube with another user.""" - mos_instance = get_mos_instance() - success = mos_instance.share_cube_with_user(cube_id, share_request.target_user_id) - if success: - return SimpleResponse(message="Cube shared successfully") - else: - raise ValueError("Failed to share cube") - - -@app.post("/memories", summary="Create memories", response_model=SimpleResponse) -async def add_memory(memory_create: MemoryCreate): - """Store new memories in a MemCube.""" - if not any([memory_create.messages, memory_create.memory_content, memory_create.doc_path]): - raise ValueError("Either messages, memory_content, or doc_path must be provided") - mos_instance = get_mos_instance() - if memory_create.messages: - messages = [m.model_dump() for m in memory_create.messages] - mos_instance.add( - messages=messages, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - elif memory_create.memory_content: - mos_instance.add( - memory_content=memory_create.memory_content, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - elif memory_create.doc_path: - mos_instance.add( - doc_path=memory_create.doc_path, - mem_cube_id=memory_create.mem_cube_id, - user_id=memory_create.user_id, - ) - return SimpleResponse(message="Memories added successfully") - - -@app.get("/memories", summary="Get all memories", response_model=MemoryResponse) -async def get_all_memories( - mem_cube_id: str | None = None, - user_id: str | None = None, -): - """Retrieve all memories from a MemCube.""" - mos_instance = get_mos_instance() - result = mos_instance.get_all(mem_cube_id=mem_cube_id, user_id=user_id) - return MemoryResponse(message="Memories retrieved successfully", data=result) - - -@app.get( - "/memories/{mem_cube_id}/{memory_id}", summary="Get a memory", response_model=MemoryResponse -) -async def get_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None): - """Retrieve a specific memory by ID from a MemCube.""" - mos_instance = get_mos_instance() - result = mos_instance.get(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id) - return MemoryResponse(message="Memory retrieved successfully", data=result) - - -@app.post("/search", summary="Search memories", response_model=SearchResponse) -async def search_memories(search_req: SearchRequest): - """Search for memories across MemCubes.""" - mos_instance = get_mos_instance() - result = mos_instance.search( - query=search_req.query, - user_id=search_req.user_id, - install_cube_ids=search_req.install_cube_ids, - ) - return SearchResponse(message="Search completed successfully", data=result) - - -@app.put( - "/memories/{mem_cube_id}/{memory_id}", summary="Update a memory", response_model=SimpleResponse -) -async def update_memory( - mem_cube_id: str, memory_id: str, updated_memory: dict[str, Any], user_id: str | None = None -): - """Update an existing memory in a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.update( - mem_cube_id=mem_cube_id, - memory_id=memory_id, - text_memory_item=updated_memory, - user_id=user_id, - ) - return SimpleResponse(message="Memory updated successfully") - - -@app.delete( - "/memories/{mem_cube_id}/{memory_id}", summary="Delete a memory", response_model=SimpleResponse -) -async def delete_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None): - """Delete a specific memory from a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.delete(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id) - return SimpleResponse(message="Memory deleted successfully") - - -@app.delete("/memories/{mem_cube_id}", summary="Delete all memories", response_model=SimpleResponse) -async def delete_all_memories(mem_cube_id: str, user_id: str | None = None): - """Delete all memories from a MemCube.""" - mos_instance = get_mos_instance() - mos_instance.delete_all(mem_cube_id=mem_cube_id, user_id=user_id) - return SimpleResponse(message="All memories deleted successfully") - - -@app.post("/chat", summary="Chat with MemOS", response_model=ChatResponse) -async def chat(chat_req: ChatRequest): - """Chat with the MemOS system.""" - mos_instance = get_mos_instance() - response = mos_instance.chat(query=chat_req.query, user_id=chat_req.user_id) - if response is None: - raise ValueError("No response generated") - return ChatResponse(message="Chat response generated", data=response) - - -@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False) -async def home(): - """Redirect to the OpenAPI documentation.""" - return RedirectResponse(url="/docs", status_code=307) - - -@app.exception_handler(ValueError) -async def value_error_handler(request: Request, exc: ValueError): - """Handle ValueError exceptions globally.""" - return JSONResponse( - status_code=400, - content={"code": 400, "message": str(exc), "data": None}, - ) - - -@app.exception_handler(Exception) -async def global_exception_handler(request: Request, exc: Exception): - """Handle all unhandled exceptions globally.""" - logger.exception("Unhandled error:") - return JSONResponse( - status_code=500, - content={"code": 500, "message": str(exc), "data": None}, - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=int, default=8000, help="Port to run the server on") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on") - parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development") - args = parser.parse_args() diff --git a/src/memos/cli.py b/src/memos/cli.py index 092f2d276..2ead5ab29 100644 --- a/src/memos/cli.py +++ b/src/memos/cli.py @@ -11,9 +11,16 @@ from io import BytesIO +def get_openapi_app(): + """Return the FastAPI app used for OpenAPI export.""" + from memos.api.server_api import app + + return app + + def export_openapi(output: str) -> bool: """Export OpenAPI schema to JSON file.""" - from memos.api.server_api import app + app = get_openapi_app() # Create directory if it doesn't exist if os.path.dirname(output): diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 5900d2357..98de09812 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -103,57 +103,6 @@ def validate_community(self): return self -class NebulaGraphDBConfig(BaseGraphDBConfig): - """ - NebulaGraph-specific configuration. - - Key concepts: - - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space. - - `user_name`: Used for logical tenant isolation if needed. - - `auto_create`: Whether to automatically create the target space if it does not exist. - - Example: - --- - hosts = ["127.0.0.1:9669"] - user = "root" - password = "nebula" - space = "shared_graph" - user_name = "alice" - """ - - space: str = Field( - ..., description="The name of the target NebulaGraph space (like a database)" - ) - user_name: str | None = Field( - default=None, - description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)", - ) - auto_create: bool = Field( - default=False, - description="Whether to auto-create the space if it does not exist", - ) - use_multi_db: bool = Field( - default=True, - description=( - "If True: use Neo4j's multi-database feature for physical isolation; " - "each user typically gets a separate database. " - "If False: use a single shared database with logical isolation by user_name." - ), - ) - max_client: int = Field( - default=1000, - description=("max_client"), - ) - embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") - - @model_validator(mode="after") - def validate_config(self): - """Validate config.""" - if not self.space: - raise ValueError("`space` must be provided") - return self - - class PolarDBGraphDBConfig(BaseConfig): """ PolarDB-specific configuration. @@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDBConfig, "neo4j-community": Neo4jCommunityGraphDBConfig, - "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, "postgres": PostgresGraphDBConfig, } diff --git a/src/memos/context/context.py b/src/memos/context/context.py index 5c8401732..5347de880 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -155,7 +155,7 @@ def get_current_user_name() -> str | None: def get_current_source() -> str | None: """ - Get the current request's source (e.g., 'product_api' or 'server_api'). + Get the current request's source (for example, 'server_api'). """ context = _request_context.get() if context: diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index c207e3190..93b5971ec 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -2,7 +2,6 @@ from memos.configs.graph_db import GraphDBConfigFactory from memos.graph_dbs.base import BaseGraphDB -from memos.graph_dbs.nebular import NebulaGraphDB from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB @@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDB, "neo4j-community": Neo4jCommunityGraphDB, - "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, "postgres": PostgresGraphDB, } diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py deleted file mode 100644 index 428d6d09e..000000000 --- a/src/memos/graph_dbs/nebular.py +++ /dev/null @@ -1,1794 +0,0 @@ -import json -import traceback - -from contextlib import suppress -from datetime import datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, ClassVar, Literal - -import numpy as np - -from memos.configs.graph_db import NebulaGraphDBConfig -from memos.dependency import require_python_package -from memos.graph_dbs.base import BaseGraphDB -from memos.log import get_logger -from memos.utils import timed - - -if TYPE_CHECKING: - from nebulagraph_python import ( - NebulaClient, - ) - - -logger = get_logger(__name__) - - -_TRANSIENT_ERR_KEYS = ( - "Session not found", - "Connection not established", - "timeout", - "deadline exceeded", - "Broken pipe", - "EOFError", - "socket closed", - "connection reset", - "connection refused", -) - - -@timed -def _normalize(vec: list[float]) -> list[float]: - v = np.asarray(vec, dtype=np.float32) - norm = np.linalg.norm(v) - return (v / (norm if norm else 1.0)).tolist() - - -@timed -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - metadata = item.get("metadata", {}) - return node_id, memory, metadata - - -@timed -def _escape_str(value: str) -> str: - out = [] - for ch in value: - code = ord(ch) - if ch == "\\": - out.append("\\\\") - elif ch == '"': - out.append('\\"') - elif ch == "\n": - out.append("\\n") - elif ch == "\r": - out.append("\\r") - elif ch == "\t": - out.append("\\t") - elif ch == "\b": - out.append("\\b") - elif ch == "\f": - out.append("\\f") - elif code < 0x20 or code in (0x2028, 0x2029): - out.append(f"\\u{code:04x}") - else: - out.append(ch) - return "".join(out) - - -@timed -def _format_datetime(value: str | datetime) -> str: - """Ensure datetime is in ISO 8601 format string.""" - if isinstance(value, datetime): - return value.isoformat() - return str(value) - - -@timed -def _normalize_datetime(val): - """ - Normalize datetime to ISO 8601 UTC string with +00:00. - - If val is datetime object -> keep isoformat() (Neo4j) - - If val is string without timezone -> append +00:00 (Nebula) - - Otherwise just str() - """ - if hasattr(val, "isoformat"): - return val.isoformat() - if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")): - return val + "+08:00" - return str(val) - - -class NebulaGraphDB(BaseGraphDB): - """ - NebulaGraph-based implementation of a graph memory store. - """ - - # ====== shared pool cache & refcount ====== - # These are process-local; in a multi-process model each process will - # have its own cache. - _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {} - _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {} - _CLIENT_LOCK: ClassVar[Lock] = Lock() - _CLIENT_INIT_DONE: ClassVar[set[str]] = set() - - @staticmethod - def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]: - hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None) - if isinstance(hosts, str): - return [hosts] - return list(hosts or []) - - @staticmethod - def _make_client_key(cfg: NebulaGraphDBConfig) -> str: - hosts = NebulaGraphDB._get_hosts_from_cfg(cfg) - return "|".join( - [ - "nebula-sync", - ",".join(hosts), - str(getattr(cfg, "user", "")), - str(getattr(cfg, "space", "")), - ] - ) - - @classmethod - def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB": - tmp = object.__new__(NebulaGraphDB) - tmp.config = cfg - tmp.db_name = cfg.space - tmp.user_name = None - tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) - tmp.default_memory_dimension = 3072 - tmp.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - tmp.base_fields = set(tmp.common_fields) - {"usage"} - tmp.heavy_fields = {"usage"} - tmp.dim_field = ( - f"embedding_{tmp.embedding_dimension}" - if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) - else "embedding" - ) - tmp.system_db_name = cfg.space - tmp._client = client - tmp._owns_client = False - return tmp - - @classmethod - def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]: - from nebulagraph_python import ( - ConnectionConfig, - NebulaClient, - SessionConfig, - SessionPoolConfig, - ) - - key = cls._make_client_key(cfg) - with cls._CLIENT_LOCK: - client = cls._CLIENT_CACHE.get(key) - if client is None: - # Connection setting - - tmp_client = NebulaClient( - hosts=cfg.uri, - username=cfg.user, - password=cfg.password, - session_config=SessionConfig(graph=None), - session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), - ) - try: - cls._ensure_space_exists(tmp_client, cfg) - finally: - tmp_client.close() - - conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) - if conn_conf is None: - conn_conf = ConnectionConfig.from_defults( - cls._get_hosts_from_cfg(cfg), - getattr(cfg, "ssl_param", None), - ) - - sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) - pool_conf = SessionPoolConfig( - size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000 - ) - - client = NebulaClient( - hosts=conn_conf.hosts, - username=cfg.user, - password=cfg.password, - conn_config=conn_conf, - session_config=sess_conf, - session_pool_config=pool_conf, - ) - cls._CLIENT_CACHE[key] = client - cls._CLIENT_REFCOUNT[key] = 0 - logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}") - - cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1 - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - try: - pass - finally: - pass - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_INIT_DONE: - admin = cls._bootstrap_admin(cfg, client) - try: - admin._ensure_database_exists() - admin._create_basic_property_indexes() - admin._create_vector_index( - dimensions=int( - admin.embedding_dimension or admin.default_memory_dimension - ), - ) - cls._CLIENT_INIT_DONE.add(key) - logger.info("[NebulaGraphDBSync] One-time init done") - except Exception: - logger.exception("[NebulaGraphDBSync] One-time init failed") - - return key, client - - def _refresh_client(self): - """ - refresh NebulaClient: - """ - old_key = getattr(self, "_client_key", None) - if not old_key: - return - - cls = self.__class__ - with cls._CLIENT_LOCK: - try: - if old_key in cls._CLIENT_CACHE: - try: - cls._CLIENT_CACHE[old_key].close() - except Exception as e: - logger.warning(f"[refresh_client] close old client error: {e}") - finally: - cls._CLIENT_CACHE.pop(old_key, None) - finally: - cls._CLIENT_REFCOUNT[old_key] = 0 - - new_key, new_client = cls._get_or_create_shared_client(self.config) - self._client_key = new_key - self._client = new_client - logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}") - - @classmethod - def _release_shared_client(cls, key: str): - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_CACHE: - return - cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1) - if cls._CLIENT_REFCOUNT[key] == 0: - try: - cls._CLIENT_CACHE[key].close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}") - finally: - cls._CLIENT_CACHE.pop(key, None) - cls._CLIENT_REFCOUNT.pop(key, None) - logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}") - - @classmethod - def close_all_shared_clients(cls): - with cls._CLIENT_LOCK: - for key, client in list(cls._CLIENT_CACHE.items()): - try: - client.close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}") - finally: - logger.info(f"[NebulaGraphDBSync] Closed client key={key}") - cls._CLIENT_CACHE.clear() - cls._CLIENT_REFCOUNT.clear() - - @require_python_package( - import_name="nebulagraph_python", - install_command="pip install nebulagraph-python>=5.1.1", - install_link=".....", - ) - def __init__(self, config: NebulaGraphDBConfig): - """ - NebulaGraph DB client initialization. - - Required config attributes: - - hosts: list[str] like ["host1:port", "host2:port"] - - user: str - - password: str - - db_name: str (optional for basic commands) - - Example config: - { - "hosts": ["xxx.xx.xx.xxx:xxxx"], - "user": "root", - "password": "nebula", - "space": "test" - } - """ - - assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" - self.config = config - self.db_name = config.space - self.user_name = config.user_name - self.embedding_dimension = config.embedding_dimension - self.default_memory_dimension = 3072 - self.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - self.base_fields = set(self.common_fields) - {"usage"} - self.heavy_fields = {"usage"} - self.dim_field = ( - f"embedding_{self.embedding_dimension}" - if (str(self.embedding_dimension) != str(self.default_memory_dimension)) - else "embedding" - ) - self.system_db_name = config.space - - # ---- NEW: pool acquisition strategy - # Get or create a shared pool from the class-level cache - self._client_key, self._client = self._get_or_create_shared_client(config) - self._owns_client = True - - logger.info("Connected to NebulaGraph successfully.") - - @timed - def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True): - def _wrap_use_db(q: str) -> str: - if auto_set_db and self.db_name: - return f"USE `{self.db_name}`\n{q}" - return q - - try: - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - - except Exception as e: - emsg = str(e) - if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS): - logger.warning(f"[execute_query] {e!s} โ†’ refreshing session pool and retry once...") - try: - self._refresh_client() - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - except Exception: - logger.exception("[execute_query] retry after refresh failed") - raise - raise - - @timed - def close(self): - """ - Close the connection resource if this instance owns it. - - - If pool was injected (`shared_pool`), do nothing. - - If pool was acquired via shared cache, decrement refcount and close - when the last owner releases it. - """ - if not self._owns_client: - logger.debug("[NebulaGraphDBSync] close() skipped (injected client).") - return - if self._client_key: - self._release_shared_client(self._client_key) - self._client_key = None - self._client = None - - # NOTE: __del__ is best-effort; do not rely on GC order. - def __del__(self): - with suppress(Exception): - self.close() - - @timed - def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 3072, - index_name: str = "memory_vector_index", - ) -> None: - # Create vector index - self._create_vector_index(label, vector_property, dimensions, index_name) - # Create indexes - self._create_basic_property_indexes() - - @timed - def remove_oldest_memory( - self, memory_type: str, keep_latest: int, user_name: str | None = None - ) -> None: - """ - Remove all WorkingMemory nodes except the latest `keep_latest` entries. - - Args: - memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). - keep_latest (int): Number of latest WorkingMemory entries to keep. - user_name(str): optional user_name. - """ - try: - user_name = user_name if user_name else self.config.user_name - optional_condition = f"AND n.user_name = '{user_name}'" - count = self.count_nodes(memory_type, user_name) - if count > keep_latest: - delete_query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {int(keep_latest)} - DETACH DELETE n - """ - self.execute_query(delete_query) - except Exception as e: - logger.warning(f"Delete old mem error: {e}") - - @timed - def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None - ) -> None: - """ - Insert or update a Memory node in NebulaGraph. - """ - metadata["user_name"] = user_name if user_name else self.config.user_name - now = datetime.utcnow() - metadata = metadata.copy() - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - metadata["node_type"] = metadata.pop("type") - metadata["id"] = id - metadata["memory"] = memory - - if "embedding" in metadata and isinstance(metadata["embedding"], list): - assert len(metadata["embedding"]) == self.embedding_dimension, ( - f"input embedding dimension must equal to {self.embedding_dimension}" - ) - embedding = metadata.pop("embedding") - metadata[self.dim_field] = _normalize(embedding) - - metadata = self._metadata_filter(metadata) - properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()) - gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - - try: - self.execute_query(gql) - logger.info("insert success") - except Exception as e: - logger.error( - f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}" - ) - - @timed - def node_not_exist(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {filter_clause} - RETURN n.id AS id - LIMIT 1 - """ - - try: - result = self.execute_query(query) - return result.size == 0 - except Exception as e: - logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) - raise - - @timed - def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: - """ - Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. - """ - user_name = user_name if user_name else self.config.user_name - fields = fields.copy() - set_clauses = [] - for k, v in fields.items(): - set_clauses.append(f"n.{k} = {self._format_value(v, k)}") - - set_clause_str = ",\n ".join(set_clauses) - - query = f""" - MATCH (n@Memory {{id: "{id}"}}) - """ - query += f'WHERE n.user_name = "{user_name}"' - - query += f"\nSET {set_clause_str}" - self.execute_query(query) - - @timed - def delete_node(self, id: str, user_name: str | None = None) -> None: - """ - Delete a node from the graph. - Args: - id: Node identifier to delete. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} - DETACH DELETE n - """ - self.execute_query(query) - - @timed - def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): - """ - Create an edge from source node to target node. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). - user_name (str, optional): User name for filtering in non-multi-db mode - """ - if not source_id or not target_id: - raise ValueError("[add_edge] source_id and target_id must be provided") - user_name = user_name if user_name else self.config.user_name - props = "" - props = f'{{user_name: "{user_name}"}}' - insert_stmt = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT (a) -[e@{type} {props}]-> (b) - ''' - try: - self.execute_query(insert_stmt) - except Exception as e: - logger.error(f"Failed to insert edge: {e}", exc_info=True) - - @timed - def delete_edge( - self, source_id: str, target_id: str, type: str, user_name: str | None = None - ) -> None: - """ - Delete a specific edge between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type to remove. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (a@Memory) -[r@{type}]-> (b@Memory) - WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} - """ - - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - query += "\nDELETE r" - self.execute_query(query) - - @timed - def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{memory_type}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN COUNT(n) AS count" - - try: - result = self.execute_query(query) - return result.one_or_none()["count"].value - except Exception as e: - logger.error(f"[get_memory_count] Failed: {e}") - return -1 - - @timed - def count_nodes(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN count(n) AS count" - - result = self.execute_query(query) - return result.one_or_none()["count"].value - - @timed - def edge_exists( - self, - source_id: str, - target_id: str, - type: str = "ANY", - direction: str = "OUTGOING", - user_name: str | None = None, - ) -> bool: - """ - Check if an edge exists between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type. Use "ANY" to match any relationship type. - direction: Direction of the edge. - Use "OUTGOING" (default), "INCOMING", or "ANY". - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - True if the edge exists, otherwise False. - """ - # Prepare the relationship pattern - user_name = user_name if user_name else self.config.user_name - rel = "r" if type == "ANY" else f"r@{type}" - - # Prepare the match pattern with direction - if direction == "OUTGOING": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})" - elif direction == "INCOMING": - pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})" - elif direction == "ANY": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})" - else: - raise ValueError( - f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." - ) - query = f"MATCH {pattern}" - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - query += "\nRETURN r" - - # Run the Cypher query - result = self.execute_query(query) - record = result.one_or_none() - if record is None: - return False - return record.values() is not None - - @timed - # Graph Query & Reasoning - def get_node( - self, id: str, include_embedding: bool = False, user_name: str | None = None - ) -> dict[str, Any] | None: - """ - Retrieve a Memory node by its unique ID. - - Args: - id (str): Node ID (Memory.id) - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - dict: Node properties as key-value pairs, or None if not found. - """ - filter_clause = f'n.id = "{id}"' - return_fields = self._build_return_fields(include_embedding) - gql = f""" - MATCH (n@Memory) - WHERE {filter_clause} - RETURN {return_fields} - """ - - try: - result = self.execute_query(gql) - for row in result: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - return node - - except Exception as e: - logger.error( - f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}" - ) - return None - - @timed - def get_nodes( - self, - ids: list[str], - include_embedding: bool = False, - user_name: str | None = None, - **kwargs, - ) -> list[dict[str, Any]]: - """ - Retrieve the metadata and memory of a list of nodes. - Args: - ids: List of Node identifier. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. - - Notes: - - Assumes all provided IDs are valid and exist. - - Returns empty list if input is empty. - """ - if not ids: - return [] - # Safe formatting of the ID list - id_list = ",".join(f'"{_id}"' for _id in ids) - - return_fields = self._build_return_fields(include_embedding) - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] - RETURN {return_fields} - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error( - f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}" - ) - return nodes - - @timed - def get_edges( - self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None - ) -> list[dict[str, str]]: - """ - Get edges connected to a node, with optional type and direction filter. - - Args: - id: Node ID to retrieve edges for. - type: Relationship type to match, or 'ANY' to match all. - direction: 'OUTGOING', 'INCOMING', or 'ANY'. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of edges: - [ - {"from": "source_id", "to": "target_id", "type": "RELATE"}, - ... - ] - """ - # Build relationship type filter - rel_type = "" if type == "ANY" else f"@{type}" - user_name = user_name if user_name else self.config.user_name - # Build Cypher pattern based on direction - if direction == "OUTGOING": - pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: - raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - - query = f""" - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type - """ - - result = self.execute_query(query) - edges = [] - for record in result: - edges.append( - { - "from": record["from_id"].value, - "to": record["to_id"].value, - "type": record["edge_type"].value, - } - ) - return edges - - @timed - def get_neighbors_by_tag( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - include_embedding: bool = False, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of dicts with node details and overlap count. - """ - if not tags: - return [] - user_name = user_name if user_name else self.config.user_name - where_clauses = [ - 'n.status = "activated"', - 'NOT (n.node_type = "reasoning")', - 'NOT (n.memory_type = "WorkingMemory")', - ] - if exclude_ids: - where_clauses.append(f"NOT (n.id IN {exclude_ids})") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_clause = " AND ".join(where_clauses) - tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" - - return_fields = self._build_return_fields(include_embedding) - query = f""" - LET tag_list = {tag_list_literal} - - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - RETURN {return_fields}, - size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count - ORDER BY overlap_count DESC - LIMIT {top_k} - """ - - result = self.execute_query(query) - neighbors: list[dict[str, Any]] = [] - for r in result: - props = {k: v.value for k, v in r.items() if k != "overlap_count"} - parsed = self._parse_node(props) - parsed["overlap_count"] = r["overlap_count"].value - neighbors.append(parsed) - - neighbors.sort(key=lambda x: x["overlap_count"], reverse=True) - neighbors = neighbors[:top_k] - result = [] - for neighbor in neighbors[:top_k]: - neighbor.pop("overlap_count") - result.append(neighbor) - return result - - @timed - def get_children_with_embeddings( - self, id: str, user_name: str | None = None - ) -> list[dict[str, Any]]: - user_name = user_name if user_name else self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" - - query = f""" - MATCH (p@Memory)-[@PARENT]->(c@Memory) - WHERE p.id = "{id}" {where_user} - RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory - """ - result = self.execute_query(query) - children = [] - for row in result: - eid = row["id"].value # STRING - emb_v = row[self.dim_field].value # NVector - emb = list(emb_v.values) if emb_v else [] - mem = row["memory"].value # STRING - - children.append({"id": eid, "embedding": emb, "memory": mem}) - return children - - @timed - def get_subgraph( - self, - center_id: str, - depth: int = 2, - center_status: str = "activated", - user_name: str | None = None, - ) -> dict[str, Any]: - """ - Retrieve a local subgraph centered at a given node. - Args: - center_id: The ID of the center node. - depth: The hop distance for neighbors. - center_status: Required status for center node. - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - { - "core_node": {...}, - "neighbors": [...], - "edges": [...] - } - """ - if not 1 <= depth <= 5: - raise ValueError("depth must be 1-5") - - user_name = user_name if user_name else self.config.user_name - - gql = f""" - MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory) - WHERE neighbor.user_name = '{user_name}' - RETURN center, - collect(DISTINCT neighbor) AS neighbors, - collect(EDGES(p)) AS edge_chains - """ - - result = self.execute_query(gql).one_or_none() - if not result or result.size == 0: - return {"core_node": None, "neighbors": [], "edges": []} - - core_node_props = result["center"].as_node().get_properties() - core_node = self._parse_node(core_node_props) - neighbors = [] - vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]} - for n in result["neighbors"].value: - n_node = n.as_node() - n_props = n_node.get_properties() - node_parsed = self._parse_node(n_props) - neighbors.append(node_parsed) - vid_to_id_map[n_node.node_id] = node_parsed["id"] - - edges = [] - for chain_group in result["edge_chains"].value: - for edge_wr in chain_group.value: - edge = edge_wr.value - edges.append( - { - "type": edge.get_type(), - "source": vid_to_id_map.get(edge.get_src_id()), - "target": vid_to_id_map.get(edge.get_dst_id()), - } - ) - - return {"core_node": core_node, "neighbors": neighbors, "edges": edges} - - @timed - # Search / recall operations - def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity. - - Args: - vector (list[float]): The embedding vector representing query semantics. - top_k (int): Number of top similar nodes to retrieve. - scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory'). - status (str, optional): Node status filter (e.g., 'active', 'archived'). - If provided, restricts results to nodes with matching status. - threshold (float, optional): Minimum similarity score threshold (0 ~ 1). - search_filter (dict, optional): Additional metadata filters for search results. - Keys should match node properties, values are the expected values. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. - - Notes: - - This method uses Neo4j native vector indexing to search for similar nodes. - - If scope is provided, it restricts results to nodes with matching memory_type. - - If 'status' is provided, only nodes with the matching status will be returned. - - If threshold is provided, only results with score >= threshold will be returned. - - If search_filter is provided, additional WHERE clauses will be added for metadata filtering. - - Typical use case: restrict to 'status = activated' to avoid - matching archived or merged nodes. - """ - user_name = user_name if user_name else self.config.user_name - vector = _normalize(vector) - dim = len(vector) - vector_str = ",".join(f"{float(x)}" for x in vector) - gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - where_clauses = [f"n.{self.dim_field} IS NOT NULL"] - if scope: - where_clauses.append(f'n.memory_type = "{scope}"') - if status: - where_clauses.append(f'n.status = "{status}"') - where_clauses.append(f'n.user_name = "{user_name}"') - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - gql = f""" - let a = {gql_vector} - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, a) DESC - LIMIT {top_k} - RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" - try: - result = self.execute_query(gql) - except Exception as e: - logger.error(f"[search_by_embedding] Query failed: {e}") - return [] - - try: - output = [] - for row in result: - values = row.values() - id_val = values[0].as_string() - score_val = values[1].as_double() - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output - except Exception as e: - logger.error(f"[search_by_embedding] Result parse failed: {e}") - return [] - - @timed - def get_by_metadata( - self, filters: list[dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """ - 1. ADD logic: "AND" vs "OR"(support logic combination); - 2. Support nested conditional expressions; - - Retrieve node IDs that match given metadata filters. - Supports exact match. - - Args: - filters: List of filter dicts like: - [ - {"field": "key", "op": "in", "value": ["A", "B"]}, - {"field": "confidence", "op": ">=", "value": 80}, - {"field": "tags", "op": "contains", "value": "AI"}, - ... - ] - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[str]: Node IDs whose metadata match the filter conditions. (AND logic). - - Notes: - - Supports structured querying such as tag/category/importance/time filtering. - - Can be used for faceted recall or prefiltering before embedding rerank. - """ - where_clauses = [] - user_name = user_name if user_name else self.config.user_name - for _i, f in enumerate(filters): - field = f["field"] - op = f.get("op", "=") - value = f["value"] - - escaped_value = self._format_value(value) - - # Build WHERE clause - if op == "=": - where_clauses.append(f"n.{field} = {escaped_value}") - elif op == "in": - where_clauses.append(f"n.{field} IN {escaped_value}") - elif op == "contains": - where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") - elif op == "starts_with": - where_clauses.append(f"n.{field} STARTS WITH {escaped_value}") - elif op == "ends_with": - where_clauses.append(f"n.{field} ENDS WITH {escaped_value}") - elif op in [">", ">=", "<", "<="]: - where_clauses.append(f"n.{field} {op} {escaped_value}") - else: - raise ValueError(f"Unsupported operator: {op}") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" - ids = [] - try: - result = self.execute_query(gql) - ids = [record["id"].value for record in result] - except Exception as e: - logger.error(f"Failed to get metadata: {e}, gql is {gql}") - return ids - - @timed - def get_grouped_counts( - self, - group_fields: list[str], - where_clause: str = "", - params: dict[str, Any] | None = None, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Count nodes grouped by any fields. - - Args: - group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] - where_clause (str, optional): Extra WHERE condition. E.g., - "WHERE n.status = 'activated'" - params (dict, optional): Parameters for WHERE clause. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] - """ - if not group_fields: - raise ValueError("group_fields cannot be empty") - user_name = user_name if user_name else self.config.user_name - # GQL-specific modifications - user_clause = f"n.user_name = '{user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" - else: - where_clause = f"WHERE {user_clause}" - - # Inline parameters if provided - if params: - for key, value in params.items(): - # Handle different value types appropriately - if isinstance(value, str): - value = f"'{value}'" - where_clause = where_clause.replace(f"${key}", str(value)) - - return_fields = [] - group_by_fields = [] - - for field in group_fields: - alias = field.replace(".", "_") - return_fields.append(f"n.{field} AS {alias}") - group_by_fields.append(alias) - # Full GQL query construction - gql = f""" - MATCH (n /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {", ".join(return_fields)}, COUNT(n) AS count - """ - result = self.execute_query(gql) # Pure GQL string execution - - output = [] - for record in result: - group_values = {} - for i, field in enumerate(group_fields): - value = record.values()[i].as_string() - group_values[field] = value - count_value = record["count"].value - output.append({**group_values, "count": count_value}) - - return output - - @timed - def clear(self, user_name: str | None = None) -> None: - """ - Clear the entire graph if the target database exists. - - Args: - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - try: - query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" - self.execute_query(query) - logger.info("Cleared all nodes from database.") - - except Exception as e: - logger.error(f"[ERROR] Failed to clear database: {e}") - - @timed - def export_graph( - self, include_embedding: bool = False, user_name: str | None = None, **kwargs - ) -> dict[str, Any]: - """ - Export all graph nodes and edges in a structured form. - Args: - include_embedding (bool): Whether to include the large embedding field. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - { - "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] - } - """ - user_name = user_name if user_name else self.config.user_name - node_query = "MATCH (n@Memory)" - edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - node_query += f' WHERE n.user_name = "{user_name}"' - edge_query += f' WHERE r.user_name = "{user_name}"' - - try: - if include_embedding: - return_fields = "n" - else: - return_fields = ",".join( - [ - "n.id AS id", - "n.memory AS memory", - "n.user_name AS user_name", - "n.user_id AS user_id", - "n.session_id AS session_id", - "n.status AS status", - "n.key AS key", - "n.confidence AS confidence", - "n.tags AS tags", - "n.created_at AS created_at", - "n.updated_at AS updated_at", - "n.memory_type AS memory_type", - "n.sources AS sources", - "n.source AS source", - "n.node_type AS node_type", - "n.visibility AS visibility", - "n.usage AS usage", - "n.background AS background", - ] - ) - - full_node_query = f"{node_query} RETURN {return_fields}" - node_result = self.execute_query(full_node_query, timeout=20) - nodes = [] - logger.debug(f"Debugging: {node_result}") - for row in node_result: - if include_embedding: - props = row.values()[0].as_node().get_properties() - else: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - nodes.append(node) - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - - try: - full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge" - edge_result = self.execute_query(full_edge_query, timeout=20) - edges = [ - { - "source": row.values()[0].value, - "target": row.values()[1].value, - "type": row.values()[2].value, - } - for row in edge_result - ] - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - - return {"nodes": nodes, "edges": edges} - - @timed - def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: - """ - Import the entire graph from a serialized dictionary. - - Args: - data: A dictionary containing all nodes and edges to be loaded. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - for node in data.get("nodes", []): - try: - id, memory, metadata = _compose_node(node) - metadata["user_name"] = user_name - metadata = self._prepare_node_metadata(metadata) - metadata.update({"id": id, "memory": memory}) - properties = ", ".join( - f"{k}: {self._format_value(v, k)}" for k, v in metadata.items() - ) - node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - self.execute_query(node_gql) - except Exception as e: - logger.error(f"Fail to load node: {node}, error: {e}") - - for edge in data.get("edges", []): - try: - source_id, target_id = edge["source"], edge["target"] - edge_type = edge["type"] - props = f'{{user_name: "{user_name}"}}' - edge_gql = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) - ''' - self.execute_query(edge_gql) - except Exception as e: - logger.error(f"Fail to load edge: {edge}, error: {e}") - - @timed - def get_all_memory_items( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> (list)[dict]: - """ - Retrieve all memory items of a specific memory_type. - - Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: Full list of memory items under this scope. - """ - user_name = user_name if user_name else self.config.user_name - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: - raise ValueError(f"Unsupported memory type scope: {scope}") - - where_clause = f"WHERE n.memory_type = '{scope}'" - where_clause += f" AND n.user_name = '{user_name}'" - - return_fields = self._build_return_fields(include_embedding) - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {return_fields} - LIMIT 100 - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error(f"Failed to get memories: {e}") - return nodes - - @timed - def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> list[dict]: - """ - Find nodes that are likely candidates for structure optimization: - - Isolated nodes, nodes with empty background, or nodes with exactly one child. - - Plus: the child of any parent node that has exactly one child. - """ - user_name = user_name if user_name else self.config.user_name - where_clause = f''' - n.memory_type = "{scope}" - AND n.status = "activated" - ''' - where_clause += f' AND n.user_name = "{user_name}"' - - return_fields = self._build_return_fields(include_embedding) - return_fields += f", n.{self.dim_field} AS {self.dim_field}" - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) - OPTIONAL MATCH (p@Memory)-[@PARENT]->(n) - WHERE c IS NULL AND p IS NULL - RETURN {return_fields} - """ - - candidates = [] - node_ids = set() - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - except Exception as e: - logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") - return candidates - - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - - @timed - def detect_conflicts(self) -> list[tuple[str, str]]: - """ - Detect conflicting nodes based on logical or semantic inconsistency. - Returns: - A list of (node_id1, node_id2) tuples that conflict. - """ - raise NotImplementedError - - @timed - # Structure Maintenance - def deduplicate_nodes(self) -> None: - """ - Deduplicate redundant or semantically similar nodes. - This typically involves identifying nodes with identical or near-identical memory. - """ - raise NotImplementedError - - @timed - def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: - """ - Get the ordered context chain starting from a node, following a relationship type. - Args: - id: Starting node ID. - type: Relationship type to follow (e.g., 'FOLLOWS'). - Returns: - List of ordered node IDs in the chain. - """ - raise NotImplementedError - - @timed - def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" - ) -> list[str]: - """ - Get connected node IDs in a specific direction and relationship type. - Args: - id: Source node ID. - type: Relationship type. - direction: Edge direction to follow ('out', 'in', or 'both'). - Returns: - List of neighboring node IDs. - """ - raise NotImplementedError - - @timed - def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: - """ - Get the path of nodes from source to target within a limited depth. - Args: - source_id: Starting node ID. - target_id: Target node ID. - max_depth: Maximum path length to traverse. - Returns: - Ordered list of node IDs along the path. - """ - raise NotImplementedError - - @timed - def merge_nodes(self, id1: str, id2: str) -> str: - """ - Merge two similar or duplicate nodes into one. - Args: - id1: First node ID. - id2: Second node ID. - Returns: - ID of the resulting merged node. - """ - raise NotImplementedError - - @classmethod - def _ensure_space_exists(cls, tmp_client, cfg): - """Lightweight check to ensure target graph (space) exists.""" - db_name = getattr(cfg, "space", None) - if not db_name: - logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.") - return - - try: - res = tmp_client.execute("SHOW GRAPHS") - existing = {row.values()[0].as_string() for row in res} - if db_name not in existing: - tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type") - logger.info(f"โœ… Graph `{db_name}` created before session binding.") - else: - logger.debug(f"Graph `{db_name}` already exists.") - except Exception: - logger.exception("[NebulaGraphDBSync] Failed to ensure space exists") - - @timed - def _ensure_database_exists(self): - graph_type_name = "MemOSBgeM3Type" - - check_type_query = "SHOW GRAPH TYPES" - result = self.execute_query(check_type_query, auto_set_db=False) - - type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result) - - if not type_exists: - create_tag = f""" - CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{ - NODE Memory (:MemoryTag {{ - id STRING, - memory STRING, - user_name STRING, - user_id STRING, - session_id STRING, - status STRING, - key STRING, - confidence FLOAT, - tags LIST, - created_at STRING, - updated_at STRING, - memory_type STRING, - sources LIST, - source STRING, - node_type STRING, - visibility STRING, - usage LIST, - background STRING, - {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>, - PRIMARY KEY(id) - }}), - EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory), - EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory) - }} - """ - self.execute_query(create_tag, auto_set_db=False) - else: - describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}" - desc_result = self.execute_query(describe_query, auto_set_db=False) - - memory_fields = [] - for row in desc_result: - field_name = row.values()[0].as_string() - memory_fields.append(field_name) - - if self.dim_field not in memory_fields: - alter_query = f""" - ALTER GRAPH TYPE {graph_type_name} {{ - ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }} - }} - """ - self.execute_query(alter_query, auto_set_db=False) - logger.info(f"โœ… Add new vector search {self.dim_field} to {graph_type_name}") - else: - logger.info(f"โœ… Graph Type {graph_type_name} already include {self.dim_field}") - - create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}" - try: - self.execute_query(create_graph, auto_set_db=False) - logger.info(f"โœ… Graph ``{self.db_name}`` is now the working graph.") - except Exception as e: - logger.error(f"โŒ Failed to create tag: {e} trace: {traceback.format_exc()}") - - @timed - def _create_vector_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 3072, - index_name: str = "memory_vector_index", - ) -> None: - """ - Create a vector index for the specified property in the label. - """ - if str(dimensions) == str(self.default_memory_dimension): - index_name = f"idx_{vector_property}" - vector_name = vector_property - else: - index_name = f"idx_{vector_property}_{dimensions}" - vector_name = f"{vector_property}_{dimensions}" - - create_vector_index = f""" - CREATE VECTOR INDEX IF NOT EXISTS {index_name} - ON NODE {label}::{vector_name} - OPTIONS {{ - DIM: {dimensions}, - METRIC: IP, - TYPE: IVF, - NLIST: 100, - TRAINSIZE: 1000 - }} - FOR `{self.db_name}` - """ - self.execute_query(create_vector_index) - logger.info( - f"โœ… Ensure {label}::{vector_property} vector index {index_name} " - f"exists (DIM={dimensions})" - ) - - @timed - def _create_basic_property_indexes(self) -> None: - """ - Create standard B-tree indexes on status, memory_type, created_at - and updated_at fields. - Create standard B-tree indexes on user_name when use Shared Database - Multi-Tenant Mode. - """ - fields = [ - "status", - "memory_type", - "created_at", - "updated_at", - "user_name", - ] - - for field in fields: - index_name = f"idx_memory_{field}" - gql = f""" - CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field}) - FOR `{self.db_name}` - """ - try: - self.execute_query(gql) - logger.info(f"โœ… Created index: {index_name} on field {field}") - except Exception as e: - logger.error( - f"โŒ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}" - ) - - @timed - def _index_exists(self, index_name: str) -> bool: - """ - Check if an index with the given name exists. - """ - """ - Check if a vector index with the given name exists in NebulaGraph. - - Args: - index_name (str): The name of the index to check. - - Returns: - bool: True if the index exists, False otherwise. - """ - query = "SHOW VECTOR INDEXES" - try: - result = self.execute_query(query) - return any(row.values()[0].as_string() == index_name for row in result) - except Exception as e: - logger.error(f"[Nebula] Failed to check index existence: {e}") - return False - - @timed - def _parse_value(self, value: Any) -> Any: - """turn Nebula ValueWrapper to Python type""" - from nebulagraph_python.value_wrapper import ValueWrapper - - if value is None or (hasattr(value, "is_null") and value.is_null()): - return None - try: - prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value - except Exception as e: - logger.warning(f"Error when decode Nebula ValueWrapper: {e}") - prim = value.cast() if isinstance(value, ValueWrapper) else value - - if isinstance(prim, ValueWrapper): - return self._parse_value(prim) - if isinstance(prim, list): - return [self._parse_value(v) for v in prim] - if type(prim).__name__ == "NVector": - return list(prim.values) - - return prim # already a Python primitive - - def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]: - parsed = {k: self._parse_value(v) for k, v in props.items()} - - for tf in ("created_at", "updated_at"): - if tf in parsed and parsed[tf] is not None: - parsed[tf] = _normalize_datetime(parsed[tf]) - - node_id = parsed.pop("id") - memory = parsed.pop("memory", "") - parsed.pop("user_name", None) - metadata = parsed - metadata["type"] = metadata.pop("node_type") - - if self.dim_field in metadata: - metadata["embedding"] = metadata.pop(self.dim_field) - - return {"id": node_id, "memory": memory, "metadata": metadata} - - @timed - def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """ - Ensure metadata has proper datetime fields and normalized types. - - - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). - - Convert embedding to list of float if present. - """ - now = datetime.utcnow().isoformat() - metadata["node_type"] = metadata.pop("type") - - # Fill timestamps if missing - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - - # Normalize embedding type - embedding = metadata.get("embedding") - if embedding and isinstance(embedding, list): - metadata.pop("embedding") - metadata[self.dim_field] = _normalize([float(x) for x in embedding]) - - return metadata - - @timed - def _format_value(self, val: Any, key: str = "") -> str: - from nebulagraph_python.py_data_types import NVector - - # None - if val is None: - return "NULL" - # bool - if isinstance(val, bool): - return "true" if val else "false" - # str - if isinstance(val, str): - return f'"{_escape_str(val)}"' - # num - elif isinstance(val, (int | float)): - return str(val) - # time - elif isinstance(val, datetime): - return f'datetime("{val.isoformat()}")' - # list - elif isinstance(val, list): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - return f"[{', '.join(self._format_value(v) for v in val)}]" - # NVector - elif isinstance(val, NVector): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - logger.warning("Invalid NVector") - # dict - if isinstance(val, dict): - j = json.dumps(val, ensure_ascii=False, separators=(",", ":")) - return f'"{_escape_str(j)}"' - else: - return f'"{_escape_str(str(val))}"' - - @timed - def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]: - """ - Filter and validate metadata dictionary against the Memory node schema. - - Removes keys not in schema. - - Warns if required fields are missing. - """ - - dim_fields = {self.dim_field} - - allowed_fields = self.common_fields | dim_fields - - missing_fields = allowed_fields - metadata.keys() - if missing_fields: - logger.info(f"Metadata missing required fields: {sorted(missing_fields)}") - - filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields} - - return filtered_metadata - - def _build_return_fields(self, include_embedding: bool = False) -> str: - fields = set(self.base_fields) - if include_embedding: - fields.add(self.dim_field) - return ", ".join(f"n.{f} AS {f}" for f in fields) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 816dcd86d..31aab9407 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -443,7 +443,7 @@ def remove_oldest_memory( ) user_name = user_name if user_name else self._get_config_value("user_name") - # Use actual OFFSET logic, consistent with nebular.py + # Use actual OFFSET logic for deterministic pruning # First find IDs to delete, then delete them select_query = f""" SELECT id FROM "{self.db_name}_graph"."Memory" @@ -3531,7 +3531,7 @@ def get_neighbors_by_tag_ccl( user_name = user_name if user_name else self._get_config_value("user_name") - # Build query conditions; keep consistent with nebular.py + # Build query conditions shared with other graph backends where_clauses = [ 'n.status = "activated"', 'NOT (n.node_type = "reasoning")', @@ -3584,7 +3584,7 @@ def get_neighbors_by_tag_ccl( # Add overlap_count result_fields.append("overlap_count agtype") result_fields_str = ", ".join(result_fields) - # Use Cypher query; keep consistent with nebular.py + # Use Cypher query to keep the graph query path aligned query = f""" SELECT * FROM ( SELECT * FROM cypher('{self.db_name}_graph', $$ diff --git a/src/memos/mem_os/client.py b/src/memos/mem_os/client.py deleted file mode 100644 index f4a591e59..000000000 --- a/src/memos/mem_os/client.py +++ /dev/null @@ -1,5 +0,0 @@ -# TODO: @Li Ji - - -class ClientMOS: - pass diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py deleted file mode 100644 index b2c74c384..000000000 --- a/src/memos/mem_os/product.py +++ /dev/null @@ -1,1610 +0,0 @@ -import asyncio -import json -import os -import random -import time - -from collections.abc import Generator -from datetime import datetime -from typing import Any, Literal - -from dotenv import load_dotenv -from transformers import AutoTokenizer - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.context.context import ContextThread -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_os.core import MOSCore -from memos.mem_os.utils.format_utils import ( - clean_json_response, - convert_graph_to_tree_forworkmem, - ensure_unique_tree_ids, - filter_nodes_by_tree_ids, - remove_embedding_recursive, - sort_children_by_memory_type, -) -from memos.mem_os.utils.reference_utils import ( - prepare_reference_data, - process_streaming_references_complete, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.schemas.task_schemas import ( - ANSWER_TASK_LABEL, - QUERY_TASK_LABEL, -) -from memos.mem_user.persistent_factory import PersistentUserManagerFactory -from memos.mem_user.user_manager import UserRole -from memos.memories.textual.item import ( - TextualMemoryItem, -) -from memos.templates.mos_prompts import ( - FURTHER_SUGGESTION_PROMPT, - SUGGESTION_QUERY_PROMPT_EN, - SUGGESTION_QUERY_PROMPT_ZH, - get_memos_prompt, -) -from memos.types import MessageList -from memos.utils import timed - - -logger = get_logger(__name__) - -load_dotenv() - -CUBE_PATH = os.getenv("MOS_CUBE_PATH", "/tmp/data/") - - -def _short_id(mem_id: str) -> str: - return (mem_id or "").split("-")[0] if mem_id else "" - - -def _format_mem_block(memories_all, max_items: int = 20, max_chars_each: int = 320) -> str: - """ - Modify TextualMemoryItem Format: - 1:abcd :: [P] text... - 2:ef01 :: [O] text... - sequence is [i:memId] i; [P]=PersonalMemory / [O]=OuterMemory - """ - if not memories_all: - return "(none)", "(none)" - - lines_o = [] - lines_p = [] - for idx, m in enumerate(memories_all[:max_items], 1): - mid = _short_id(getattr(m, "id", "") or "") - mtype = getattr(getattr(m, "metadata", {}), "memory_type", None) or getattr( - m, "metadata", {} - ).get("memory_type", "") - tag = "O" if "Outer" in str(mtype) else "P" - txt = (getattr(m, "memory", "") or "").replace("\n", " ").strip() - if len(txt) > max_chars_each: - txt = txt[: max_chars_each - 1] + "โ€ฆ" - mid = mid or f"mem_{idx}" - if tag == "O": - lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n") - elif tag == "P": - lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}") - return "\n".join(lines_o), "\n".join(lines_p) - - -class MOSProduct(MOSCore): - """ - The MOSProduct class inherits from MOSCore and manages multiple users. - Each user has their own configuration and cube access, but shares the same model instances. - """ - - def __init__( - self, - default_config: MOSConfig | None = None, - max_user_instances: int = 1, - default_cube_config: GeneralMemCubeConfig | None = None, - online_bot=None, - error_bot=None, - ): - """ - Initialize MOSProduct with an optional default configuration. - - Args: - default_config (MOSConfig | None): Default configuration for new users - max_user_instances (int): Maximum number of user instances to keep in memory - default_cube_config (GeneralMemCubeConfig | None): Default cube configuration for loading cubes - online_bot: DingDing online_bot function or None if disabled - error_bot: DingDing error_bot function or None if disabled - """ - # Initialize with a root config for shared resources - if default_config is None: - # Create a minimal config for root user - root_config = MOSConfig( - user_id="root", - session_id="root_session", - chat_model=default_config.chat_model if default_config else None, - mem_reader=default_config.mem_reader if default_config else None, - enable_mem_scheduler=default_config.enable_mem_scheduler - if default_config - else False, - mem_scheduler=default_config.mem_scheduler if default_config else None, - ) - else: - root_config = default_config.model_copy(deep=True) - root_config.user_id = "root" - root_config.session_id = "root_session" - - # Create persistent user manager BEFORE calling parent constructor - persistent_user_manager_client = PersistentUserManagerFactory.from_config( - config_factory=root_config.user_manager - ) - - # Initialize parent MOSCore with root config and persistent user manager - super().__init__(root_config, user_manager=persistent_user_manager_client) - - # Product-specific attributes - self.default_config = default_config - self.default_cube_config = default_cube_config - self.max_user_instances = max_user_instances - self.online_bot = online_bot - self.error_bot = error_bot - - # User-specific data structures - self.user_configs: dict[str, MOSConfig] = {} - self.user_cube_access: dict[str, set[str]] = {} # user_id -> set of cube_ids - self.user_chat_histories: dict[str, dict] = {} - - # Note: self.user_manager is now the persistent user manager from parent class - # No need for separate global_user_manager as they are the same instance - - # Initialize tiktoken for streaming - try: - # Use gpt2 encoding which is more stable and widely compatible - self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") - logger.info("tokenizer initialized successfully for streaming") - except Exception as e: - logger.warning( - f"Failed to initialize tokenizer, will use character-based chunking: {e}" - ) - self.tokenizer = None - - # Restore user instances from persistent storage - self._restore_user_instances(default_cube_config=default_cube_config) - logger.info(f"User instances restored successfully, now user is {self.mem_cubes.keys()}") - - def _restore_user_instances( - self, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Restore user instances from persistent storage after service restart. - - Args: - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - """ - try: - # Get all user configurations from persistent storage - user_configs = self.user_manager.list_user_configs(self.max_user_instances) - - # Get the raw database records for sorting by updated_at - session = self.user_manager._get_session() - try: - from memos.mem_user.persistent_user_manager import UserConfig - - db_configs = session.query(UserConfig).limit(self.max_user_instances).all() - # Create a mapping of user_id to updated_at timestamp - updated_at_map = {config.user_id: config.updated_at for config in db_configs} - - # Sort by updated_at timestamp (most recent first) and limit by max_instances - sorted_configs = sorted( - user_configs.items(), key=lambda x: updated_at_map.get(x[0], ""), reverse=True - )[: self.max_user_instances] - finally: - session.close() - - for user_id, config in sorted_configs: - if user_id != "root": # Skip root user - try: - # Store user config and cube access - self.user_configs[user_id] = config - self._load_user_cube_access(user_id) - - # Pre-load all cubes for this user with default config - self._preload_user_cubes(user_id, default_cube_config) - - logger.info( - f"Restored user configuration and pre-loaded cubes for {user_id}" - ) - - except Exception as e: - logger.error(f"Failed to restore user configuration for {user_id}: {e}") - - except Exception as e: - logger.error(f"Error during user instance restoration: {e}") - - def _initialize_cube_from_default_config( - self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig - ) -> GeneralMemCube | None: - """ - Initialize a cube from default configuration when cube path doesn't exist. - - Args: - cube_id (str): The cube ID to initialize. - user_id (str): The user ID for the cube. - default_config (GeneralMemCubeConfig): The default configuration to use. - """ - cube_config = default_config.model_copy(deep=True) - # Safely modify the graph_db user_name if it exists - if cube_config.text_mem.config.graph_db.config: - cube_config.text_mem.config.graph_db.config.user_name = ( - f"memos{user_id.replace('-', '')}" - ) - mem_cube = GeneralMemCube(config=cube_config) - return mem_cube - - def _preload_user_cubes( - self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Pre-load all cubes for a user into memory. - - Args: - user_id (str): The user ID to pre-load cubes for. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - """ - try: - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cubes(user_id) - - for cube in accessible_cubes: - if cube.cube_id not in self.mem_cubes: - try: - if cube.cube_path and os.path.exists(cube.cube_path): - # Pre-load cube with all memory types and default config - self.register_mem_cube( - cube.cube_path, - cube.cube_id, - user_id, - memory_types=["act_mem"] - if self.config.enable_activation_memory - else [], - default_config=default_cube_config, - ) - logger.info(f"Pre-loaded cube {cube.cube_id} for user {user_id}") - else: - logger.warning( - f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, skipping pre-load" - ) - except Exception as e: - logger.error( - f"Failed to pre-load cube {cube.cube_id} for user {user_id}: {e}", - exc_info=True, - ) - - except Exception as e: - logger.error(f"Error pre-loading cubes for user {user_id}: {e}", exc_info=True) - - @timed - def _load_user_cubes( - self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None - ) -> None: - """Load all cubes for a user into memory. - - Args: - user_id (str): The user ID to load cubes for. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - """ - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cubes(user_id) - - for cube in accessible_cubes[:1]: - if cube.cube_id not in self.mem_cubes: - try: - if cube.cube_path and os.path.exists(cube.cube_path): - # Use MOSCore's register_mem_cube method directly with default config - # Only load act_mem since text_mem is stored in database - self.register_mem_cube( - cube.cube_path, - cube.cube_id, - user_id, - memory_types=["act_mem"], - default_config=default_cube_config, - ) - else: - logger.warning( - f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config" - ) - cube_obj = self._initialize_cube_from_default_config( - cube_id=cube.cube_id, - user_id=user_id, - default_config=default_cube_config, - ) - if cube_obj: - self.register_mem_cube( - cube_obj, - cube.cube_id, - user_id, - memory_types=[], - ) - else: - raise ValueError( - f"Failed to initialize default cube {cube.cube_id} for user {user_id}" - ) - except Exception as e: - logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}") - logger.info(f"load user {user_id} cubes successfully") - - def _ensure_user_instance(self, user_id: str, max_instances: int | None = None) -> None: - """ - Ensure user configuration exists, creating it if necessary. - - Args: - user_id (str): The user ID - max_instances (int): Maximum instances to keep in memory (overrides class default) - """ - if user_id in self.user_configs: - return - - # Try to get config from persistent storage first - stored_config = self.user_manager.get_user_config(user_id) - if stored_config: - self.user_configs[user_id] = stored_config - self._load_user_cube_access(user_id) - else: - # Use default config - if not self.default_config: - raise ValueError(f"No configuration available for user {user_id}") - user_config = self.default_config.model_copy(deep=True) - user_config.user_id = user_id - user_config.session_id = f"{user_id}_session" - self.user_configs[user_id] = user_config - self._load_user_cube_access(user_id) - - # Apply LRU eviction if needed - max_instances = max_instances or self.max_user_instances - if len(self.user_configs) > max_instances: - # Remove least recently used instance (excluding root) - user_ids = [uid for uid in self.user_configs if uid != "root"] - if user_ids: - oldest_user_id = user_ids[0] - del self.user_configs[oldest_user_id] - if oldest_user_id in self.user_cube_access: - del self.user_cube_access[oldest_user_id] - logger.info(f"Removed least recently used user configuration: {oldest_user_id}") - - def _load_user_cube_access(self, user_id: str) -> None: - """Load user's cube access permissions.""" - try: - # Get user's accessible cubes from persistent storage - accessible_cubes = self.user_manager.get_user_cube_access(user_id) - self.user_cube_access[user_id] = set(accessible_cubes) - except Exception as e: - logger.warning(f"Failed to load cube access for user {user_id}: {e}") - self.user_cube_access[user_id] = set() - - def _get_user_config(self, user_id: str) -> MOSConfig: - """Get user configuration.""" - if user_id not in self.user_configs: - self._ensure_user_instance(user_id) - return self.user_configs[user_id] - - def _validate_user_cube_access(self, user_id: str, cube_id: str) -> None: - """Validate user has access to the cube.""" - if user_id not in self.user_cube_access: - self._load_user_cube_access(user_id) - - if cube_id not in self.user_cube_access.get(user_id, set()): - raise ValueError(f"User '{user_id}' does not have access to cube '{cube_id}'") - - def _validate_user_access(self, user_id: str, cube_id: str | None = None) -> None: - """Validate user access using MOSCore's built-in validation.""" - # Use MOSCore's built-in user validation - if cube_id: - self._validate_cube_access(user_id, cube_id) - else: - self._validate_user_exists(user_id) - - def _create_user_config(self, user_id: str, config: MOSConfig) -> MOSConfig: - """Create a new user configuration.""" - # Create a copy of config with the specific user_id - user_config = config.model_copy(deep=True) - user_config.user_id = user_id - user_config.session_id = f"{user_id}_session" - - # Save configuration to persistent storage - self.user_manager.save_user_config(user_id, user_config) - - return user_config - - def _get_or_create_user_config( - self, user_id: str, config: MOSConfig | None = None - ) -> MOSConfig: - """Get existing user config or create a new one.""" - if user_id in self.user_configs: - return self.user_configs[user_id] - - # Try to get config from persistent storage first - stored_config = self.user_manager.get_user_config(user_id) - if stored_config: - return self._create_user_config(user_id, stored_config) - - # Use provided config or default config - user_config = config or self.default_config - if not user_config: - raise ValueError(f"No configuration provided for user {user_id}") - - return self._create_user_config(user_id, user_config) - - def _build_system_prompt( - self, - memories_all: list[TextualMemoryItem], - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - ) -> str: - """ - Build custom system prompt for the user with memory references. - - Args: - user_id (str): The user ID. - memories (list[TextualMemoryItem]): The memories to build the system prompt. - - Returns: - str: The custom system prompt. - """ - # Build base prompt - # Add memory context if available - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt( - date=formatted_date, tone=tone, verbosity=verbosity, mode="base" - ) - mem_block_o, mem_block_p = _format_mem_block(memories_all) - mem_block = mem_block_o + "\n" + mem_block_p - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return ( - prefix - + sys_body - + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n" - + mem_block - ) - - def _build_base_system_prompt( - self, - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - mode: str = "enhance", - ) -> str: - """ - Build base system prompt without memory references. - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode) - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return prefix + sys_body - - def _build_memory_context( - self, - memories_all: list[TextualMemoryItem], - mode: str = "enhance", - ) -> str: - """ - Build memory context to be included in user message. - """ - if not memories_all: - return "" - - mem_block_o, mem_block_p = _format_mem_block(memories_all) - - if mode == "enhance": - return ( - "# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - + "\n\n" - ) - else: - mem_block = mem_block_o + "\n" + mem_block_p - return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n" - - def _build_enhance_system_prompt( - self, - user_id: str, - memories_all: list[TextualMemoryItem], - tone: str = "friendly", - verbosity: str = "mid", - ) -> str: - """ - Build enhance prompt for the user with memory references. - [DEPRECATED] Use _build_base_system_prompt and _build_memory_context instead. - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt( - date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance" - ) - mem_block_o, mem_block_p = _format_mem_block(memories_all) - return ( - sys_body - + "\n\n# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - ) - - def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: - """ - Extract reference information from the response and return clean text. - - Args: - response (str): The complete response text. - - Returns: - tuple[str, list[dict]]: A tuple containing: - - clean_text: Text with reference markers removed - - references: List of reference information - """ - import re - - try: - references = [] - # Pattern to match [refid:memoriesID] - pattern = r"\[(\d+):([^\]]+)\]" - - matches = re.findall(pattern, response) - for ref_number, memory_id in matches: - references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) - - # Remove all reference markers from the text to get clean text - clean_text = re.sub(pattern, "", response) - - # Clean up any extra whitespace that might be left after removing markers - clean_text = re.sub(r"\s+", " ", clean_text).strip() - - return clean_text, references - except Exception as e: - logger.error(f"Error extracting references from response: {e}", exc_info=True) - return response, [] - - def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict: - """ - get struct message from chat-history - # TODO: @xcy make this more general - """ - system_content = "" - memory_content = "" - chat_history = [] - - for item in chat_data: - role = item.get("role") - content = item.get("content", "") - if role == "system": - parts = content.split("# Memories", 1) - system_content = parts[0].strip() - if len(parts) > 1: - memory_content = "# Memories" + parts[1].strip() - elif role in ("user", "assistant"): - chat_history.append({"role": role, "content": content}) - - if chat_history and chat_history[-1]["role"] == "assistant": - if len(chat_history) >= 2 and chat_history[-2]["role"] == "user": - chat_history = chat_history[:-2] - else: - chat_history = chat_history[:-1] - - return {"system": system_content, "memory": memory_content, "chat_history": chat_history} - - def _chunk_response_with_tiktoken( - self, response: str, chunk_size: int = 5 - ) -> Generator[str, None, None]: - """ - Chunk response using tiktoken for proper token-based streaming. - - Args: - response (str): The response text to chunk. - chunk_size (int): Number of tokens per chunk. - - Yields: - str: Chunked text pieces. - """ - if self.tokenizer: - # Use tiktoken for proper token-based chunking - tokens = self.tokenizer.encode(response) - - for i in range(0, len(tokens), chunk_size): - token_chunk = tokens[i : i + chunk_size] - chunk_text = self.tokenizer.decode(token_chunk) - yield chunk_text - else: - # Fallback to character-based chunking - char_chunk_size = chunk_size * 4 # Approximate character to token ratio - for i in range(0, len(response), char_chunk_size): - yield response[i : i + char_chunk_size] - - def _send_message_to_scheduler( - self, - user_id: str, - mem_cube_id: str, - query: str, - label: str, - ): - """ - Send message to scheduler. - args: - user_id: str, - mem_cube_id: str, - query: str, - """ - - if self.enable_mem_scheduler and (self.mem_scheduler is not None): - message_item = ScheduleMessageItem( - user_id=user_id, - mem_cube_id=mem_cube_id, - label=label, - content=query, - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) - - async def _post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions - """ - try: - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}" - ) - logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}") - - clean_response, extracted_references = self._extract_references_from_response( - full_response - ) - struct_message = self._extract_struct_data_from_history(current_messages) - logger.info(f"Extracted {len(extracted_references)} references from response") - - # Send chat report notifications asynchronously - if self.online_bot: - logger.info("Online Bot Open!") - try: - from memos.memos_tools.notification_utils import ( - send_online_bot_notification_async, - ) - - # Prepare notification data - chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id} - chat_data.update( - { - "memory": struct_message["memory"], - "chat_history": struct_message["chat_history"], - "full_response": full_response, - } - ) - - system_data = { - "references": extracted_references, - "time_start": time_start, - "time_end": time_end, - "speed_improvement": speed_improvement, - } - - emoji_config = {"chat": "๐Ÿ’ฌ", "system_info": "๐Ÿ“Š"} - - await send_online_bot_notification_async( - online_bot=self.online_bot, - header_name="MemOS Chat Report", - sub_title_name="chat_with_references", - title_color="#00956D", - other_data1=chat_data, - other_data2=system_data, - emoji=emoji_config, - ) - except Exception as e: - logger.warning(f"Failed to send chat notification (async): {e}") - - self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL - ) - - self.add( - user_id=user_id, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - mem_cube_id=cube_id, - ) - - logger.info(f"Post-chat processing completed for user {user_id}") - - except Exception as e: - logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True) - - def _start_post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments - """ - logger.info("Start post_chat_processing...") - - def run_async_in_thread(): - """Running asynchronous tasks in a new thread""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - finally: - loop.close() - except Exception as e: - logger.error( - f"Error in thread-based post-chat processing for user {user_id}: {e}", - exc_info=True, - ) - - try: - # Try to get the current event loop - asyncio.get_running_loop() - # Create task and store reference to prevent garbage collection - task = asyncio.create_task( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - # Add exception handling for the background task - task.add_done_callback( - lambda t: ( - logger.error( - f"Error in background post-chat processing for user {user_id}: {t.exception()}", - exc_info=True, - ) - if t.exception() - else None - ) - ) - except RuntimeError: - # No event loop, run in a new thread with context propagation - thread = ContextThread( - target=run_async_in_thread, - name=f"PostChatProcessing-{user_id}", - # Set as a daemon thread to avoid blocking program exit - daemon=True, - ) - thread.start() - - def _filter_memories_by_threshold( - self, - memories: list[TextualMemoryItem], - threshold: float = 0.30, - min_num: int = 3, - memory_type: Literal["OuterMemory"] = "OuterMemory", - ) -> list[TextualMemoryItem]: - """ - Filter memories by threshold and type, at least min_num memories for Non-OuterMemory. - Args: - memories: list[TextualMemoryItem], - threshold: float, - min_num: int, - memory_type: Literal["OuterMemory"], - Returns: - list[TextualMemoryItem] - """ - sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True) - filtered_person = [m for m in memories if m.metadata.memory_type != memory_type] - filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type] - filtered = [] - per_memory_count = 0 - for m in sorted_memories: - if m.metadata.relativity >= threshold: - if m.metadata.memory_type != memory_type: - per_memory_count += 1 - filtered.append(m) - if len(filtered) < min_num: - filtered = filtered_person[:min_num] + filtered_outer[:min_num] - else: - if per_memory_count < min_num: - filtered += filtered_person[per_memory_count:min_num] - filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True) - return filtered_memory - - def register_mem_cube( - self, - mem_cube_name_or_path_or_object: str | GeneralMemCube, - mem_cube_id: str | None = None, - user_id: str | None = None, - memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None, - default_config: GeneralMemCubeConfig | None = None, - ) -> None: - """ - Register a MemCube with the MOS. - - Args: - mem_cube_name_or_path_or_object (str | GeneralMemCube): The name, path, or GeneralMemCube object to register. - mem_cube_id (str, optional): The identifier for the MemCube. If not provided, a default ID is used. - user_id (str, optional): The user ID to register the cube for. - memory_types (list[str], optional): List of memory types to load. - If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] - default_config (GeneralMemCubeConfig, optional): Default configuration for the cube. - """ - # Handle different input types - if isinstance(mem_cube_name_or_path_or_object, GeneralMemCube): - # Direct GeneralMemCube object provided - mem_cube = mem_cube_name_or_path_or_object - if mem_cube_id is None: - mem_cube_id = f"cube_{id(mem_cube)}" # Generate a unique ID - else: - # String path provided - mem_cube_name_or_path = mem_cube_name_or_path_or_object - if mem_cube_id is None: - mem_cube_id = mem_cube_name_or_path - - if mem_cube_id in self.mem_cubes: - logger.info(f"MemCube with ID {mem_cube_id} already in MOS, skip install.") - return - - # Create MemCube from path - time_start = time.time() - if os.path.exists(mem_cube_name_or_path): - mem_cube = GeneralMemCube.init_from_dir( - mem_cube_name_or_path, memory_types, default_config - ) - logger.info( - f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}" - ) - else: - logger.warning( - f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo." - ) - mem_cube = GeneralMemCube.init_from_remote_repo( - mem_cube_name_or_path, memory_types=memory_types, default_config=default_config - ) - - # Register the MemCube - logger.info( - f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}" - ) - time_start = time.time() - self.mem_cubes[mem_cube_id] = mem_cube - time_end = time.time() - logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}") - - def user_register( - self, - user_id: str, - user_name: str | None = None, - config: MOSConfig | None = None, - interests: str | None = None, - default_mem_cube: GeneralMemCube | None = None, - default_cube_config: GeneralMemCubeConfig | None = None, - mem_cube_id: str | None = None, - ) -> dict[str, str]: - """Register a new user with configuration and default cube. - - Args: - user_id (str): The user ID for registration. - user_name (str): The user name for registration. - config (MOSConfig | None, optional): User-specific configuration. Defaults to None. - interests (str | None, optional): User interests as string. Defaults to None. - default_mem_cube (GeneralMemCube | None, optional): Default memory cube. Defaults to None. - default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None. - - Returns: - dict[str, str]: Registration result with status and message. - """ - try: - # Use provided config or default config - user_config = config or self.default_config - if not user_config: - return { - "status": "error", - "message": "No configuration provided for user registration", - } - if not user_name: - user_name = user_id - - # Create user with configuration using persistent user manager - self.user_manager.create_user_with_config(user_id, user_config, UserRole.USER, user_id) - - # Create user configuration - user_config = self._create_user_config(user_id, user_config) - - # Create a default cube for the user using MOSCore's methods - default_cube_name = f"{user_name}_{user_id}_default_cube" - mem_cube_name_or_path = os.path.join(CUBE_PATH, default_cube_name) - default_cube_id = self.create_cube_for_user( - cube_name=default_cube_name, - owner_id=user_id, - cube_path=mem_cube_name_or_path, - cube_id=mem_cube_id, - ) - time_start = time.time() - if default_mem_cube: - try: - default_mem_cube.dump(mem_cube_name_or_path, memory_types=[]) - except Exception as e: - logger.error(f"Failed to dump default cube: {e}") - time_end = time.time() - logger.info(f"time user_register: dump default cube time is: {time_end - time_start}") - # Register the default cube with MOS - self.register_mem_cube( - mem_cube_name_or_path_or_object=default_mem_cube, - mem_cube_id=default_cube_id, - user_id=user_id, - memory_types=["act_mem"] if self.config.enable_activation_memory else [], - default_config=default_cube_config, # use default cube config - ) - - # Add interests to the default cube if provided - if interests: - self.add(memory_content=interests, mem_cube_id=default_cube_id, user_id=user_id) - - return { - "status": "success", - "message": f"User {user_name} registered successfully with default cube {default_cube_id}", - "user_id": user_id, - "default_cube_id": default_cube_id, - } - - except Exception as e: - return {"status": "error", "message": f"Failed to register user: {e!s}"} - - def _get_further_suggestion(self, message: MessageList | None = None) -> list[str]: - """Get further suggestion prompt.""" - try: - dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]]) - further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info) - message_list = [{"role": "system", "content": further_suggestion_prompt}] - response = self.chat_llm.generate(message_list) - clean_response = clean_json_response(response) - response_json = json.loads(clean_response) - return response_json["query"] - except Exception as e: - logger.error(f"Error getting further suggestion: {e}", exc_info=True) - return [] - - def get_suggestion_query( - self, user_id: str, language: str = "zh", message: MessageList | None = None - ) -> list[str]: - """Get suggestion query from LLM. - Args: - user_id (str): User ID. - language (str): Language for suggestions ("zh" or "en"). - - Returns: - list[str]: The suggestion query list. - """ - if message: - further_suggestion = self._get_further_suggestion(message) - return further_suggestion - if language == "zh": - suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH - else: # English - suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN - text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[ - "text_mem" - ] - if text_mem_result: - memories = "\n".join([m.memory[:200] for m in text_mem_result[0]["memories"]]) - else: - memories = "" - message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}] - response = self.chat_llm.generate(message_list) - clean_response = clean_json_response(response) - response_json = json.loads(clean_response) - return response_json["query"] - - def chat( - self, - query: str, - user_id: str, - cube_id: str | None = None, - history: MessageList | None = None, - base_prompt: str | None = None, - internet_search: bool = False, - moscube: bool = False, - top_k: int = 10, - threshold: float = 0.5, - session_id: str | None = None, - ) -> str: - """ - Chat with LLM with memory references and complete response. - """ - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memories_result = super().search( - query, - user_id, - install_cube_ids=[cube_id] if cube_id else None, - top_k=top_k, - mode="fine", - internet_search=internet_search, - moscube=moscube, - session_id=session_id, - )["text_mem"] - - memories_list = [] - if memories_result: - memories_list = memories_result[0]["memories"] - memories_list = self._filter_memories_by_threshold(memories_list, threshold) - new_memories_list = [] - for m in memories_list: - m.metadata.embedding = [] - new_memories_list.append(m) - memories_list = new_memories_list - - system_prompt = super()._build_system_prompt(memories_list, base_prompt) - if history is not None: - # Use the provided history (even if it's empty) - history_info = history[-20:] - else: - # Fall back to internal chat_history - if user_id not in self.chat_history_manager: - self._register_chat_history(user_id, session_id) - history_info = self.chat_history_manager[user_id].chat_history[-20:] - current_messages = [ - {"role": "system", "content": system_prompt}, - *history_info, - {"role": "user", "content": query}, - ] - logger.info("Start to get final answer...") - response = self.chat_llm.generate(current_messages) - time_end = time.time() - self._start_post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=0.0, - current_messages=current_messages, - ) - return response, memories_list - - def chat_with_references( - self, - query: str, - user_id: str, - cube_id: str | None = None, - history: MessageList | None = None, - top_k: int = 20, - internet_search: bool = False, - moscube: bool = False, - session_id: str | None = None, - ) -> Generator[str, None, None]: - """ - Chat with LLM with memory references and streaming output. - - Args: - query (str): Query string. - user_id (str): User ID. - cube_id (str, optional): Custom cube ID for user. - history (MessageList, optional): Chat history. - - Returns: - Generator[str, None, None]: The response string generator with reference processing. - """ - - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memories_list = [] - yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" - memories_result = super().search( - query, - user_id, - install_cube_ids=[cube_id] if cube_id else None, - top_k=top_k, - mode="fine", - internet_search=internet_search, - moscube=moscube, - session_id=session_id, - )["text_mem"] - - yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" - search_time_end = time.time() - logger.info( - f"time chat: search text_mem time user_id: {user_id} time is: {search_time_end - time_start}" - ) - self._send_message_to_scheduler( - user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_TASK_LABEL - ) - if memories_result: - memories_list = memories_result[0]["memories"] - memories_list = self._filter_memories_by_threshold(memories_list) - - reference = prepare_reference_data(memories_list) - yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - # Build custom system prompt with relevant memories) - system_prompt = self._build_enhance_system_prompt(user_id, memories_list) - # Get chat history - if user_id not in self.chat_history_manager: - self._register_chat_history(user_id, session_id) - - chat_history = self.chat_history_manager[user_id] - if history is not None: - chat_history.chat_history = history[-20:] - current_messages = [ - {"role": "system", "content": system_prompt}, - *chat_history.chat_history, - {"role": "user", "content": query}, - ] - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}" - ) - yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" - # Generate response with custom prompt - past_key_values = None - response_stream = None - if self.config.enable_activation_memory: - # Handle activation memory (copy MOSCore logic) - for mem_cube_id, mem_cube in self.mem_cubes.items(): - if mem_cube.act_mem and mem_cube_id == cube_id: - kv_cache = next(iter(mem_cube.act_mem.get_all()), None) - past_key_values = ( - kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None - ) - if past_key_values is not None: - logger.info("past_key_values is not None will apply to chat") - else: - logger.info("past_key_values is None will not apply to chat") - break - if self.config.chat_model.backend == "huggingface": - response_stream = self.chat_llm.generate_stream( - current_messages, past_key_values=past_key_values - ) - elif self.config.chat_model.backend == "vllm": - response_stream = self.chat_llm.generate_stream(current_messages) - else: - if self.config.chat_model.backend in ["huggingface", "vllm", "openai"]: - response_stream = self.chat_llm.generate_stream(current_messages) - else: - response_stream = self.chat_llm.generate(current_messages) - - time_end = time.time() - chat_time_end = time.time() - logger.info( - f"time chat: chat time user_id: {user_id} time is: {chat_time_end - search_time_end}" - ) - # Simulate streaming output with proper reference handling using tiktoken - - # Initialize buffer for streaming - buffer = "" - full_response = "" - token_count = 0 - # Use tiktoken for proper token-based chunking - if self.config.chat_model.backend not in ["huggingface", "vllm", "openai"]: - # For non-huggingface backends, we need to collect the full response first - full_response_text = "" - for chunk in response_stream: - if chunk in ["", ""]: - continue - full_response_text += chunk - response_stream = self._chunk_response_with_tiktoken(full_response_text, chunk_size=5) - for chunk in response_stream: - if chunk in ["", ""]: - continue - token_count += 1 - buffer += chunk - full_response += chunk - - # Process buffer to ensure complete reference tags - processed_chunk, remaining_buffer = process_streaming_references_complete(buffer) - - if processed_chunk: - chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" - yield chunk_data - buffer = remaining_buffer - - # Process any remaining buffer - if buffer: - processed_chunk, remaining_buffer = process_streaming_references_complete(buffer) - if processed_chunk: - chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" - yield chunk_data - - # set kvcache improve speed - speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) - total_time = round(float(time_end - time_start), 1) - - yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n" - # get further suggestion - current_messages.append({"role": "assistant", "content": full_response}) - further_suggestion = self._get_further_suggestion(current_messages) - logger.info(f"further_suggestion: {further_suggestion}") - yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n" - yield f"data: {json.dumps({'type': 'end'})}\n\n" - - # Asynchronous processing of logs, notifications and memory additions - self._start_post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - - def get_all( - self, - user_id: str, - memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], - mem_cube_ids: list[str] | None = None, - ) -> list[dict[str, Any]]: - """Get all memory items for a user. - - Args: - user_id (str): The ID of the user. - cube_id (str | None, optional): The ID of the cube. Defaults to None. - memory_type (Literal["text_mem", "act_mem", "param_mem"]): The type of memory to get. - - Returns: - list[dict[str, Any]]: A list of memory items with cube_id and memories structure. - """ - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, self.default_cube_config) - time_start = time.time() - memory_list = super().get_all( - mem_cube_id=mem_cube_ids[0] if mem_cube_ids else None, user_id=user_id - )[memory_type] - get_all_time_end = time.time() - logger.info( - f"time get_all: get_all time user_id: {user_id} time is: {get_all_time_end - time_start}" - ) - reformat_memory_list = [] - if memory_type == "text_mem": - for memory in memory_list: - memories = remove_embedding_recursive(memory["memories"]) - custom_type_ratios = { - "WorkingMemory": 0.20, - "LongTermMemory": 0.40, - "UserMemory": 0.40, - } - tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories, target_node_count=200, type_ratios=custom_type_ratios - ) - # Ensure all node IDs are unique in the tree structure - tree_result = ensure_unique_tree_ids(tree_result) - memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) - children = tree_result["children"] - children_sort = sort_children_by_memory_type(children) - tree_result["children"] = children_sort - memories_filtered["tree_structure"] = tree_result - reformat_memory_list.append( - { - "cube_id": memory["cube_id"], - "memories": [memories_filtered], - "memory_statistics": node_type_count, - } - ) - elif memory_type == "act_mem": - memories_list = [] - act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all() - if act_mem_params: - memories_data = act_mem_params[0].model_dump() - records = memories_data.get("records", []) - for record in records["text_memories"]: - memories_list.append( - { - "id": memories_data["id"], - "text": record, - "create_time": records["timestamp"], - "size": random.randint(1, 20), - "modify_times": 1, - } - ) - reformat_memory_list.append( - { - "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0], - "memories": memories_list, - } - ) - elif memory_type == "para_mem": - act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all() - logger.info(f"act_mem_params: {act_mem_params}") - reformat_memory_list.append( - { - "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0], - "memories": act_mem_params[0].model_dump(), - } - ) - make_format_time_end = time.time() - logger.info( - f"time get_all: make_format time user_id: {user_id} time is: {make_format_time_end - get_all_time_end}" - ) - return reformat_memory_list - - def _get_subgraph( - self, query: str, mem_cube_id: str, user_id: str | None = None, top_k: int = 5 - ) -> list[dict[str, Any]]: - result = {"para_mem": [], "act_mem": [], "text_mem": []} - if self.config.enable_textual_memory and self.mem_cubes[mem_cube_id].text_mem: - result["text_mem"].append( - { - "cube_id": mem_cube_id, - "memories": self.mem_cubes[mem_cube_id].text_mem.get_relevant_subgraph( - query, top_k=top_k - ), - } - ) - return result - - def get_subgraph( - self, - user_id: str, - query: str, - mem_cube_ids: list[str] | None = None, - top_k: int = 20, - ) -> list[dict[str, Any]]: - """Get all memory items for a user. - - Args: - user_id (str): The ID of the user. - cube_id (str | None, optional): The ID of the cube. Defaults to None. - mem_cube_ids (list[str], optional): The IDs of the cubes. Defaults to None. - - Returns: - list[dict[str, Any]]: A list of memory items with cube_id and memories structure. - """ - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, self.default_cube_config) - memory_list = self._get_subgraph( - query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=top_k - )["text_mem"] - reformat_memory_list = [] - for memory in memory_list: - memories = remove_embedding_recursive(memory["memories"]) - custom_type_ratios = {"WorkingMemory": 0.20, "LongTermMemory": 0.40, "UserMemory": 0.4} - tree_result, node_type_count = convert_graph_to_tree_forworkmem( - memories, target_node_count=150, type_ratios=custom_type_ratios - ) - # Ensure all node IDs are unique in the tree structure - tree_result = ensure_unique_tree_ids(tree_result) - memories_filtered = filter_nodes_by_tree_ids(tree_result, memories) - children = tree_result["children"] - children_sort = sort_children_by_memory_type(children) - tree_result["children"] = children_sort - memories_filtered["tree_structure"] = tree_result - reformat_memory_list.append( - { - "cube_id": memory["cube_id"], - "memories": [memories_filtered], - "memory_statistics": node_type_count, - } - ) - - return reformat_memory_list - - def search( - self, - query: str, - user_id: str, - install_cube_ids: list[str] | None = None, - top_k: int = 10, - mode: Literal["fast", "fine"] = "fast", - session_id: str | None = None, - ): - """Search memories for a specific user.""" - - # Load user cubes if not already loaded - time_start = time.time() - self._load_user_cubes(user_id, self.default_cube_config) - load_user_cubes_time_end = time.time() - logger.info( - f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}" - ) - search_result = super().search( - query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id - ) - search_time_end = time.time() - logger.info( - f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}" - ) - text_memory_list = search_result["text_mem"] - reformat_memory_list = [] - for memory in text_memory_list: - memories_list = [] - for data in memory["memories"]: - memories = data.model_dump() - memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["embedding"] = [] - memories["metadata"]["sources"] = [] - memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["id"] = memories["id"] - memories["metadata"]["memory"] = memories["memory"] - memories_list.append(memories) - reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list}) - logger.info(f"search memory list is : {reformat_memory_list}") - search_result["text_mem"] = reformat_memory_list - - pref_memory_list = search_result["pref_mem"] - reformat_pref_memory_list = [] - for memory in pref_memory_list: - memories_list = [] - for data in memory["memories"]: - memories = data.model_dump() - memories["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["embedding"] = [] - memories["metadata"]["sources"] = [] - memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]" - memories["metadata"]["id"] = memories["id"] - memories["metadata"]["memory"] = memories["memory"] - memories_list.append(memories) - reformat_pref_memory_list.append( - {"cube_id": memory["cube_id"], "memories": memories_list} - ) - search_result["pref_mem"] = reformat_pref_memory_list - time_end = time.time() - logger.info( - f"time search: total time for user_id: {user_id} time is: {time_end - time_start}" - ) - return search_result - - def add( - self, - user_id: str, - messages: MessageList | None = None, - memory_content: str | None = None, - doc_path: str | None = None, - mem_cube_id: str | None = None, - source: str | None = None, - user_profile: bool = False, - session_id: str | None = None, - task_id: str | None = None, # Add task_id parameter - ): - """Add memory for a specific user.""" - - # Load user cubes if not already loaded - self._load_user_cubes(user_id, self.default_cube_config) - result = super().add( - messages, - memory_content, - doc_path, - mem_cube_id, - user_id, - session_id=session_id, - task_id=task_id, - ) - if user_profile: - try: - user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0] - user_interests = user_interests.replace(",", " ") - user_profile_memories = self.mem_cubes[ - mem_cube_id - ].text_mem.internet_retriever.retrieve_from_internet(query=user_interests, top_k=5) - for memory in user_profile_memories: - self.mem_cubes[mem_cube_id].text_mem.add(memory) - except Exception as e: - logger.error( - f"Failed to retrieve user profile: {e}, memory_content: {memory_content}" - ) - - return result - - def list_users(self) -> list: - """List all registered users.""" - return self.user_manager.list_users() - - def get_user_info(self, user_id: str) -> dict: - """Get user information including accessible cubes.""" - # Use MOSCore's built-in user validation - # Validate user access - self._validate_user_access(user_id) - - result = super().get_user_info() - - return result - - def share_cube_with_user(self, cube_id: str, owner_user_id: str, target_user_id: str) -> bool: - """Share a cube with another user.""" - # Use MOSCore's built-in cube access validation - self._validate_cube_access(owner_user_id, cube_id) - - result = super().share_cube_with_user(cube_id, target_user_id) - - return result - - def clear_user_chat_history(self, user_id: str) -> None: - """Clear chat history for a specific user.""" - # Validate user access - self._validate_user_access(user_id) - - super().clear_messages(user_id) - - def update_user_config(self, user_id: str, config: MOSConfig) -> bool: - """Update user configuration. - - Args: - user_id (str): The user ID. - config (MOSConfig): The new configuration. - - Returns: - bool: True if successful, False otherwise. - """ - try: - # Save to persistent storage - success = self.user_manager.save_user_config(user_id, config) - if success: - # Update in-memory config - self.user_configs[user_id] = config - logger.info(f"Updated configuration for user {user_id}") - - return success - except Exception as e: - logger.error(f"Failed to update user config for {user_id}: {e}") - return False - - def get_user_config(self, user_id: str) -> MOSConfig | None: - """Get user configuration. - - Args: - user_id (str): The user ID. - - Returns: - MOSConfig | None: The user's configuration or None if not found. - """ - return self.user_manager.get_user_config(user_id) - - def get_active_user_count(self) -> int: - """Get the number of active user configurations in memory.""" - return len(self.user_configs) - - def get_user_instance_info(self) -> dict[str, Any]: - """Get information about user configurations in memory.""" - return { - "active_instances": len(self.user_configs), - "max_instances": self.max_user_instances, - "user_ids": list(self.user_configs.keys()), - "lru_order": list(self.user_configs.keys()), # OrderedDict maintains insertion order - } diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py deleted file mode 100644 index 80aefea85..000000000 --- a/src/memos/mem_os/product_server.py +++ /dev/null @@ -1,457 +0,0 @@ -import asyncio -import time - -from datetime import datetime -from typing import Literal - -from memos.context.context import ContextThread -from memos.llms.base import BaseLLM -from memos.log import get_logger -from memos.mem_cube.navie import NaiveMemCube -from memos.mem_os.product import _format_mem_block -from memos.mem_reader.base import BaseMemReader -from memos.memories.textual.item import TextualMemoryItem -from memos.templates.mos_prompts import ( - get_memos_prompt, -) -from memos.types import MessageList - - -logger = get_logger(__name__) - - -class MOSServer: - def __init__( - self, - mem_reader: BaseMemReader | None = None, - llm: BaseLLM | None = None, - online_bot: bool = False, - ): - self.mem_reader = mem_reader - self.chat_llm = llm - self.online_bot = online_bot - - def chat( - self, - query: str, - user_id: str, - cube_id: str | None = None, - mem_cube: NaiveMemCube | None = None, - history: MessageList | None = None, - base_prompt: str | None = None, - internet_search: bool = False, - moscube: bool = False, - top_k: int = 10, - threshold: float = 0.5, - session_id: str | None = None, - ) -> str: - """ - Chat with LLM with memory references and complete response. - """ - time_start = time.time() - memories_result = mem_cube.text_mem.search( - query=query, - user_name=cube_id, - top_k=top_k, - mode="fine", - manual_close_internet=not internet_search, - moscube=moscube, - info={ - "user_id": user_id, - "session_id": session_id, - "chat_history": history, - }, - ) - - memories_list = [] - if memories_result: - memories_list = self._filter_memories_by_threshold(memories_result, threshold) - new_memories_list = [] - for m in memories_list: - m.metadata.embedding = [] - new_memories_list.append(m) - memories_list = new_memories_list - system_prompt = self._build_system_prompt(memories_list, base_prompt) - - history_info = [] - if history: - history_info = history[-20:] - current_messages = [ - {"role": "system", "content": system_prompt}, - *history_info, - {"role": "user", "content": query}, - ] - response = self.chat_llm.generate(current_messages) - time_end = time.time() - self._start_post_chat_processing( - user_id=user_id, - cube_id=cube_id, - session_id=session_id, - query=query, - full_response=response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=0.0, - current_messages=current_messages, - mem_cube=mem_cube, - history=history, - ) - return response, memories_list - - def add( - self, - user_id: str, - cube_id: str, - mem_cube: NaiveMemCube, - messages: MessageList, - session_id: str | None = None, - history: MessageList | None = None, - ) -> list[str]: - memories = self.mem_reader.get_memory( - [messages], - type="chat", - info={ - "user_id": user_id, - "session_id": session_id, - "chat_history": history, - }, - ) - flattened_memories = [mm for m in memories for mm in m] - mem_id_list: list[str] = mem_cube.text_mem.add( - flattened_memories, - user_name=cube_id, - ) - return mem_id_list - - def search( - self, - user_id: str, - cube_id: str, - session_id: str | None = None, - ) -> None: - NotImplementedError("Not implemented") - - def _filter_memories_by_threshold( - self, - memories: list[TextualMemoryItem], - threshold: float = 0.30, - min_num: int = 3, - memory_type: Literal["OuterMemory"] = "OuterMemory", - ) -> list[TextualMemoryItem]: - """ - Filter memories by threshold and type, at least min_num memories for Non-OuterMemory. - Args: - memories: list[TextualMemoryItem], - threshold: float, - min_num: int, - memory_type: Literal["OuterMemory"], - Returns: - list[TextualMemoryItem] - """ - sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True) - filtered_person = [m for m in memories if m.metadata.memory_type != memory_type] - filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type] - filtered = [] - per_memory_count = 0 - for m in sorted_memories: - if m.metadata.relativity >= threshold: - if m.metadata.memory_type != memory_type: - per_memory_count += 1 - filtered.append(m) - if len(filtered) < min_num: - filtered = filtered_person[:min_num] + filtered_outer[:min_num] - else: - if per_memory_count < min_num: - filtered += filtered_person[per_memory_count:min_num] - filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True) - return filtered_memory - - def _build_base_system_prompt( - self, - base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", - mode: str = "enhance", - ) -> str: - """ - Build base system prompt without memory references. - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode) - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return prefix + sys_body - - def _build_system_prompt( - self, - memories: list[TextualMemoryItem] | list[str] | None = None, - base_prompt: str | None = None, - **kwargs, - ) -> str: - """Build system prompt with optional memories context.""" - if base_prompt is None: - base_prompt = ( - "You are a knowledgeable and helpful AI assistant. " - "You have access to conversation memories that help you provide more personalized responses. " - "Use the memories to understand the user's context, preferences, and past interactions. " - "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." - ) - - memory_context = "" - if memories: - memory_list = [] - for i, memory in enumerate(memories, 1): - if isinstance(memory, TextualMemoryItem): - text_memory = memory.memory - else: - if not isinstance(memory, str): - logger.error("Unexpected memory type.") - text_memory = memory - memory_list.append(f"{i}. {text_memory}") - memory_context = "\n".join(memory_list) - - if "{memories}" in base_prompt: - return base_prompt.format(memories=memory_context) - elif base_prompt and memories: - # For backward compatibility, append memories if no placeholder is found - memory_context_with_header = "\n\n## Memories:\n" + memory_context - return base_prompt + memory_context_with_header - return base_prompt - - def _build_memory_context( - self, - memories_all: list[TextualMemoryItem], - mode: str = "enhance", - ) -> str: - """ - Build memory context to be included in user message. - """ - if not memories_all: - return "" - - mem_block_o, mem_block_p = _format_mem_block(memories_all) - - if mode == "enhance": - return ( - "# Memories\n## PersonalMemory (ordered)\n" - + mem_block_p - + "\n## OuterMemory (ordered)\n" - + mem_block_o - + "\n\n" - ) - else: - mem_block = mem_block_o + "\n" + mem_block_p - return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n" - - def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: - """ - Extract reference information from the response and return clean text. - - Args: - response (str): The complete response text. - - Returns: - tuple[str, list[dict]]: A tuple containing: - - clean_text: Text with reference markers removed - - references: List of reference information - """ - import re - - try: - references = [] - # Pattern to match [refid:memoriesID] - pattern = r"\[(\d+):([^\]]+)\]" - - matches = re.findall(pattern, response) - for ref_number, memory_id in matches: - references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) - - # Remove all reference markers from the text to get clean text - clean_text = re.sub(pattern, "", response) - - # Clean up any extra whitespace that might be left after removing markers - clean_text = re.sub(r"\s+", " ", clean_text).strip() - - return clean_text, references - except Exception as e: - logger.error(f"Error extracting references from response: {e}", exc_info=True) - return response, [] - - async def _post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - mem_cube: NaiveMemCube | None = None, - session_id: str | None = None, - history: MessageList | None = None, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions - """ - try: - logger.info( - f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}" - ) - logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}") - - clean_response, extracted_references = self._extract_references_from_response( - full_response - ) - logger.info(f"Extracted {len(extracted_references)} references from response") - - # Send chat report notifications asynchronously - if self.online_bot: - try: - from memos.memos_tools.notification_utils import ( - send_online_bot_notification_async, - ) - - # Prepare notification data - chat_data = { - "query": query, - "user_id": user_id, - "cube_id": cube_id, - "system_prompt": system_prompt, - "full_response": full_response, - } - - system_data = { - "references": extracted_references, - "time_start": time_start, - "time_end": time_end, - "speed_improvement": speed_improvement, - } - - emoji_config = {"chat": "๐Ÿ’ฌ", "system_info": "๐Ÿ“Š"} - - await send_online_bot_notification_async( - online_bot=self.online_bot, - header_name="MemOS Chat Report", - sub_title_name="chat_with_references", - title_color="#00956D", - other_data1=chat_data, - other_data2=system_data, - emoji=emoji_config, - ) - except Exception as e: - logger.warning(f"Failed to send chat notification (async): {e}") - - self.add( - user_id=user_id, - cube_id=cube_id, - mem_cube=mem_cube, - session_id=session_id, - history=history, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - ) - - logger.info(f"Post-chat processing completed for user {user_id}") - - except Exception as e: - logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True) - - def _start_post_chat_processing( - self, - user_id: str, - cube_id: str, - query: str, - full_response: str, - system_prompt: str, - time_start: float, - time_end: float, - speed_improvement: float, - current_messages: list, - mem_cube: NaiveMemCube | None = None, - session_id: str | None = None, - history: MessageList | None = None, - ) -> None: - """ - Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments - """ - - def run_async_in_thread(): - """Running asynchronous tasks in a new thread""" - try: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - mem_cube=mem_cube, - session_id=session_id, - history=history, - ) - ) - finally: - loop.close() - except Exception as e: - logger.error( - f"Error in thread-based post-chat processing for user {user_id}: {e}", - exc_info=True, - ) - - try: - # Try to get the current event loop - asyncio.get_running_loop() - # Create task and store reference to prevent garbage collection - task = asyncio.create_task( - self._post_chat_processing( - user_id=user_id, - cube_id=cube_id, - query=query, - full_response=full_response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=speed_improvement, - current_messages=current_messages, - ) - ) - # Add exception handling for the background task - task.add_done_callback( - lambda t: ( - logger.error( - f"Error in background post-chat processing for user {user_id}: {t.exception()}", - exc_info=True, - ) - if t.exception() - else None - ) - ) - except RuntimeError: - # No event loop, run in a new thread with context propagation - thread = ContextThread( - target=run_async_in_thread, - name=f"PostChatProcessing-{user_id}", - # Set as a daemon thread to avoid blocking program exit - daemon=True, - ) - thread.start() diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py index ec431c253..671190e6f 100644 --- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py +++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py @@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: graph_db_backend_map = { "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py deleted file mode 100644 index 857b290c5..000000000 --- a/tests/api/test_product_router.py +++ /dev/null @@ -1,422 +0,0 @@ -""" -Unit tests for product_router input/output format validation. - -This module tests that the product_router endpoints correctly validate -input request formats and return properly formatted responses. -""" - -from unittest.mock import Mock, patch - -import pytest - -from fastapi.testclient import TestClient - -# Patch the MOS_PRODUCT_INSTANCE directly after import -# Patch MOS_PRODUCT_INSTANCE and MOSProduct so we can test the FastAPI router -# without initializing the full MemOS product stack. -import memos.api.routers.product_router as pr_module - - -_mock_mos_instance = Mock() -pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance -pr_module.get_mos_product_instance = lambda: _mock_mos_instance -with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance): - from memos.api import product_api - - -@pytest.fixture(scope="module") -def mock_mos_product_instance(): - """Mock get_mos_product_instance for all tests.""" - # Ensure the mock is set - pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance - pr_module.get_mos_product_instance = lambda: _mock_mos_instance - yield product_api.app, _mock_mos_instance - - -@pytest.fixture -def client(mock_mos_product_instance): - """Create test client for product_api.""" - app, _ = mock_mos_product_instance - return TestClient(app) - - -@pytest.fixture -def mock_mos_product(mock_mos_product_instance): - """Get the mocked MOSProduct instance.""" - _, mock_instance = mock_mos_product_instance - # Ensure get_mos_product_instance returns this mock - import memos.api.routers.product_router as pr_module - - pr_module.get_mos_product_instance = lambda: mock_instance - pr_module.MOS_PRODUCT_INSTANCE = mock_instance - return mock_instance - - -@pytest.fixture(autouse=True) -def setup_mock_mos_product(mock_mos_product): - """Set up default return values for MOSProduct methods.""" - # Set up default return values for methods - mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_mos_product.add.return_value = None - mock_mos_product.chat.return_value = ("test response", []) - mock_mos_product.chat_with_references.return_value = iter( - ['data: {"type": "content", "data": "test"}\n\n'] - ) - # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list) - default_memory_result = [{"cube_id": "test_cube", "memories": []}] - mock_mos_product.get_all.return_value = default_memory_result - mock_mos_product.get_subgraph.return_value = default_memory_result - mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"] - # Ensure get_mos_product_instance returns the mock - import memos.api.routers.product_router as pr_module - - pr_module.get_mos_product_instance = lambda: mock_mos_product - - -class TestProductRouterSearch: - """Test /search endpoint input/output format.""" - - def test_search_valid_input_output(self, mock_mos_product, client): - """Test search endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - "top_k": 10, - } - - response = client.post("/product/search", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], dict) - - # Verify MOSProduct.search was called with correct parameters - mock_mos_product.search.assert_called_once() - call_kwargs = mock_mos_product.search.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_search_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test search endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/search", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_search_response_format(self, mock_mos_product, client): - """Test search endpoint returns SearchResponse format.""" - mock_mos_product.search.return_value = { - "text_mem": [{"cube_id": "test_cube", "memories": []}], - "act_mem": [], - "para_mem": [], - } - - request_data = { - "user_id": "test_user", - "query": "test query", - } - - response = client.post("/product/search", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Search completed successfully" - assert isinstance(data["data"], dict) - assert "text_mem" in data["data"] - - -class TestProductRouterAdd: - """Test /add endpoint input/output format.""" - - def test_add_valid_input_output(self, mock_mos_product, client): - """Test add endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/add", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert data["data"] is None # SimpleResponse has None data - - # Verify MOSProduct.add was called with correct parameters - mock_mos_product.add.assert_called_once() - call_kwargs = mock_mos_product.add.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["memory_content"] == "test memory content" - - def test_add_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test add endpoint with missing required field.""" - request_data = { - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_add_response_format(self, mock_mos_product, client): - """Test add endpoint returns SimpleResponse format.""" - request_data = { - "user_id": "test_user", - "memory_content": "test memory content", - } - - response = client.post("/product/add", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Memory created successfully" - assert data["data"] is None - - -class TestProductRouterChatComplete: - """Test /chat/complete endpoint input/output format.""" - - def test_chat_complete_valid_input_output(self, mock_mos_product, client): - """Test chat/complete endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/chat/complete", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "message" in data - assert "data" in data - assert isinstance(data["data"], dict) - assert "response" in data["data"] - assert "references" in data["data"] - - # Verify MOSProduct.chat was called with correct parameters - mock_mos_product.chat.assert_called_once() - call_kwargs = mock_mos_product.chat.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test chat/complete endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/chat/complete", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_chat_complete_response_format(self, mock_mos_product, client): - """Test chat/complete endpoint returns correct format.""" - mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}]) - - request_data = { - "user_id": "test_user", - "query": "test query", - } - - response = client.post("/product/chat/complete", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Chat completed successfully" - assert isinstance(data["data"]["response"], str) - assert isinstance(data["data"]["references"], list) - - -class TestProductRouterChat: - """Test /chat endpoint input/output format (SSE stream).""" - - def test_chat_valid_input_output(self, mock_mos_product, client): - """Test chat endpoint with valid input returns SSE stream.""" - request_data = { - "user_id": "test_user", - "query": "test query", - "mem_cube_id": "test_cube", - } - - response = client.post("/product/chat", json=request_data) - - assert response.status_code == 200 - assert "text/event-stream" in response.headers["content-type"] - - # Verify MOSProduct.chat_with_references was called - mock_mos_product.chat_with_references.assert_called_once() - call_kwargs = mock_mos_product.chat_with_references.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["query"] == "test query" - - def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test chat endpoint with missing required field.""" - request_data = { - "query": "test query", - } - - response = client.post("/product/chat", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - -class TestProductRouterSuggestions: - """Test /suggestions endpoint input/output format.""" - - def test_suggestions_valid_input_output(self, mock_mos_product, client): - """Test suggestions endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "mem_cube_id": "test_cube", - "language": "zh", - } - - response = client.post("/product/suggestions", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], dict) - assert "query" in data["data"] - - # Verify MOSProduct.get_suggestion_query was called - mock_mos_product.get_suggestion_query.assert_called_once() - call_kwargs = mock_mos_product.get_suggestion_query.call_args[1] - assert call_kwargs["user_id"] == "test_user" - - def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test suggestions endpoint with missing required field.""" - request_data = { - "mem_cube_id": "test_cube", - } - - response = client.post("/product/suggestions", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_suggestions_response_format(self, mock_mos_product, client): - """Test suggestions endpoint returns SuggestionResponse format.""" - mock_mos_product.get_suggestion_query.return_value = [ - "suggestion1", - "suggestion2", - "suggestion3", - ] - - request_data = { - "user_id": "test_user", - "mem_cube_id": "test_cube", - "language": "en", - } - - response = client.post("/product/suggestions", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Suggestions retrieved successfully" - assert isinstance(data["data"], dict) - assert isinstance(data["data"]["query"], list) - - -class TestProductRouterGetAll: - """Test /get_all endpoint input/output format.""" - - def test_get_all_valid_input_output(self, mock_mos_product, client): - """Test get_all endpoint with valid input returns correct output format.""" - request_data = { - "user_id": "test_user", - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - data = response.json() - - # Validate response structure - assert "code" in data - assert "message" in data - assert "data" in data - assert data["code"] == 200 - assert isinstance(data["data"], list) - - # Verify MOSProduct.get_all was called - mock_mos_product.get_all.assert_called_once() - call_kwargs = mock_mos_product.get_all.call_args[1] - assert call_kwargs["user_id"] == "test_user" - assert call_kwargs["memory_type"] == "text_mem" - - def test_get_all_with_search_query(self, mock_mos_product, client): - """Test get_all endpoint with search_query uses get_subgraph.""" - # Reset mock call counts - mock_mos_product.get_all.reset_mock() - mock_mos_product.get_subgraph.reset_mock() - - request_data = { - "user_id": "test_user", - "memory_type": "text_mem", - "search_query": "test query", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - # Verify get_subgraph was called instead of get_all - mock_mos_product.get_subgraph.assert_called_once() - mock_mos_product.get_all.assert_not_called() - - def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client): - """Test get_all endpoint with missing required field.""" - request_data = { - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - # Should return validation error - assert response.status_code == 422 - - def test_get_all_response_format(self, mock_mos_product, client): - """Test get_all endpoint returns MemoryResponse format.""" - mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}] - - request_data = { - "user_id": "test_user", - "memory_type": "text_mem", - } - - response = client.post("/product/get_all", json=request_data) - - assert response.status_code == 200 - data = response.json() - assert data["message"] == "Memories retrieved successfully" - assert isinstance(data["data"], list) - assert len(data["data"]) > 0 diff --git a/tests/api/test_start_api.py b/tests/api/test_start_api.py deleted file mode 100644 index e1ffcd74b..000000000 --- a/tests/api/test_start_api.py +++ /dev/null @@ -1,401 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from fastapi.testclient import TestClient - -from memos.api.start_api import app -from memos.mem_user.user_manager import UserRole - - -client = TestClient(app) - -# Mock data -MOCK_MESSAGE = {"role": "user", "content": "test message"} -MOCK_MEMORY_CREATE = { - "messages": [MOCK_MESSAGE], - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_MEMORY_CONTENT = { - "memory_content": "test memory content", - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_DOC_PATH = {"doc_path": "/path/to/doc", "mem_cube_id": "test_cube", "user_id": "test_user"} -MOCK_SEARCH_REQUEST = { - "query": "test query", - "user_id": "test_user", - "install_cube_ids": ["test_cube"], -} -MOCK_MEMCUBE_REGISTER = { - "mem_cube_name_or_path": "test_cube_path", - "mem_cube_id": "test_cube", - "user_id": "test_user", -} -MOCK_CHAT_REQUEST = {"query": "test chat query", "user_id": "test_user"} -MOCK_USER_CREATE = {"user_id": "test_user", "user_name": "Test User", "role": "USER"} -MOCK_CUBE_SHARE = {"target_user_id": "target_user"} -MOCK_CONFIG = { - "user_id": "test_user", - "session_id": "test_session", - "enable_textual_memory": True, - "enable_activation_memory": False, - "top_k": 5, - "chat_model": { - "backend": "openai", - "config": { - "model_name_or_path": "gpt-3.5-turbo", - "api_key": "test_key", - "temperature": 0.7, - "api_base": "https://api.openai.com/v1", - }, - }, -} - - -@pytest.fixture -def mock_mos(): - """Mock MOS instance for testing.""" - with patch("memos.api.start_api.get_mos_instance") as mock_get_mos: - # Create a mock MOS instance - mock_instance = Mock() - - # Set up default return values for methods - mock_instance.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_instance.get_all.return_value = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_instance.get.return_value = {"memory": "test memory"} - mock_instance.chat.return_value = "test response" - mock_instance.list_users.return_value = [] - mock_instance.get_user_info.return_value = { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "accessible_cubes": [], - } - mock_instance.create_user.return_value = "test_user" - mock_instance.share_cube_with_user.return_value = True - - # Configure the mock to return our mock instance - mock_get_mos.return_value = mock_instance - - yield mock_instance - - -def test_configure_error(mock_mos): - """Test configuration endpoint with error.""" - with patch("memos.api.start_api.MOS_INSTANCE", None): - response = client.post("/configure", json={}) - assert response.status_code == 422 # FastAPI validation error - - -def test_create_user(mock_mos): - """Test user creation endpoint.""" - response = client.post("/users", json=MOCK_USER_CREATE) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "User created successfully", - "data": {"user_id": "test_user"}, - } - mock_mos.create_user.assert_called_once_with( - user_id="test_user", role=UserRole.USER, user_name="Test User" - ) - - -def test_create_user_validation_error(mock_mos): - """Test user creation with validation error.""" - mock_mos.create_user.side_effect = ValueError("Invalid user data") - response = client.post("/users", json=MOCK_USER_CREATE) - assert response.status_code == 400 - assert "Invalid user data" in response.json()["message"] - - -def test_list_users(mock_mos): - """Test list users endpoint.""" - # Set up mock to return the expected data structure - mock_users = [ - { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "created_at": "2023-01-01T00:00:00", - "is_active": True, - } - ] - mock_mos.list_users.return_value = mock_users - - response = client.get("/users") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Users retrieved successfully", - "data": mock_users, - } - mock_mos.list_users.assert_called_once() - - -def test_get_user_info(mock_mos): - """Test get user info endpoint.""" - # Set up mock to return the expected data structure - mock_user_info = { - "user_id": "test_user", - "user_name": "Test User", - "role": "user", - "created_at": "2023-01-01T00:00:00", - "accessible_cubes": [], - } - mock_mos.get_user_info.return_value = mock_user_info - - response = client.get("/users/me") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "User info retrieved successfully", - "data": mock_user_info, - } - mock_mos.get_user_info.assert_called_once() - - -def test_register_mem_cube(mock_mos): - """Test MemCube registration endpoint.""" - response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "MemCube registered successfully", - "data": None, - } - mock_mos.register_mem_cube.assert_called_once_with( - mem_cube_name_or_path="test_cube_path", mem_cube_id="test_cube", user_id="test_user" - ) - - -def test_register_mem_cube_validation_error(mock_mos): - """Test MemCube registration with validation error.""" - mock_mos.register_mem_cube.side_effect = ValueError("Invalid MemCube") - response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER) - assert response.status_code == 400 - assert "Invalid MemCube" in response.json()["message"] - - -def test_unregister_mem_cube(mock_mos): - """Test MemCube unregistration endpoint.""" - response = client.delete("/mem_cubes/test_cube?user_id=test_user") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "MemCube unregistered successfully", - "data": None, - } - mock_mos.unregister_mem_cube.assert_called_once_with( - mem_cube_id="test_cube", user_id="test_user" - ) - - -def test_unregister_nonexistent_mem_cube(mock_mos): - """Test unregistering a non-existent MemCube.""" - mock_mos.unregister_mem_cube.side_effect = ValueError("MemCube not found") - response = client.delete("/mem_cubes/nonexistent_cube") - assert response.status_code == 400 - assert "MemCube not found" in response.json()["message"] - - -def test_share_cube(mock_mos): - """Test cube sharing endpoint.""" - response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Cube shared successfully", "data": None} - mock_mos.share_cube_with_user.assert_called_once_with("test_cube", "target_user") - - -def test_share_cube_failure(mock_mos): - """Test cube sharing failure.""" - mock_mos.share_cube_with_user.return_value = False - response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE) - assert response.status_code == 400 - assert "Failed to share cube" in response.json()["message"] - - -@pytest.mark.parametrize( - "memory_create,expected_calls", - [ - (MOCK_MEMORY_CREATE, {"messages": [MOCK_MESSAGE]}), - (MOCK_MEMORY_CONTENT, {"memory_content": "test memory content"}), - (MOCK_DOC_PATH, {"doc_path": "/path/to/doc"}), - ], -) -def test_add_memory(mock_mos, memory_create, expected_calls): - """Test adding memories with different types of content.""" - response = client.post("/memories", json=memory_create) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memories added successfully", "data": None} - mock_mos.add.assert_called_once() - - -def test_add_memory_validation_error(mock_mos): - """Test adding memory with validation error.""" - response = client.post("/memories", json={}) - assert response.status_code == 400 - assert "must be provided" in response.json()["message"] - - -def test_get_all_memories(mock_mos): - """Test get all memories endpoint.""" - mock_results = { - "text_mem": [{"cube_id": "test_cube", "memories": []}], - "act_mem": [], - "para_mem": [], - } - mock_mos.get_all.return_value = mock_results - - response = client.get("/memories") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Memories retrieved successfully", - "data": mock_results, - } - mock_mos.get_all.assert_called_once_with(mem_cube_id=None, user_id=None) - - -def test_get_memory(mock_mos): - """Test get specific memory endpoint.""" - mock_memory = {"memory": "test memory content"} - mock_mos.get.return_value = mock_memory - - response = client.get("/memories/test_cube/test_memory") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Memory retrieved successfully", - "data": mock_memory, - } - mock_mos.get.assert_called_once_with( - mem_cube_id="test_cube", memory_id="test_memory", user_id=None - ) - - -def test_get_nonexistent_memory(mock_mos): - """Test getting a non-existent memory.""" - mock_mos.get.side_effect = ValueError("Memory not found") - response = client.get("/memories/test_cube/nonexistent_memory") - assert response.status_code == 400 - assert "Memory not found" in response.json()["message"] - - -def test_search_memories(mock_mos): - """Test search memories endpoint.""" - # Mock the search method to return a proper result structure - mock_results = {"text_mem": [], "act_mem": [], "para_mem": []} - mock_mos.search.return_value = mock_results - - # Ensure the search request has all required fields - search_request = { - "query": "test query", - "user_id": "test_user", - "install_cube_ids": ["test_cube"], - } - - response = client.post("/search", json=search_request) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Search completed successfully", - "data": mock_results, - } - mock_mos.search.assert_called_once_with( - query="test query", user_id="test_user", install_cube_ids=["test_cube"] - ) - - -def test_update_memory(mock_mos): - """Test updating a memory endpoint.""" - update_data = {"content": "updated content"} - response = client.put("/memories/test_cube/test_memory?user_id=test_user", json=update_data) - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memory updated successfully", "data": None} - mock_mos.update.assert_called_once_with( - mem_cube_id="test_cube", - memory_id="test_memory", - text_memory_item=update_data, - user_id="test_user", - ) - - -def test_update_nonexistent_memory(mock_mos): - """Test updating a non-existent memory.""" - mock_mos.update.side_effect = ValueError("Memory not found") - response = client.put("/memories/test_cube/nonexistent_memory", json={}) - assert response.status_code == 400 - assert "Memory not found" in response.json()["message"] - - -def test_delete_memory(mock_mos): - """Test deleting a memory endpoint.""" - response = client.delete("/memories/test_cube/test_memory?user_id=test_user") - assert response.status_code == 200 - assert response.json() == {"code": 200, "message": "Memory deleted successfully", "data": None} - mock_mos.delete.assert_called_once_with( - mem_cube_id="test_cube", memory_id="test_memory", user_id="test_user" - ) - - -def test_delete_nonexistent_memory(mock_mos): - """Test deleting a non-existent memory.""" - mock_mos.delete.side_effect = ValueError("Memory not found") - response = client.delete("/memories/test_cube/nonexistent_memory") - assert response.status_code == 400 - assert "Memory not found" in response.json()["message"] - - -def test_delete_all_memories(mock_mos): - """Test deleting all memories endpoint.""" - response = client.delete("/memories/test_cube?user_id=test_user") - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "All memories deleted successfully", - "data": None, - } - mock_mos.delete_all.assert_called_once_with(mem_cube_id="test_cube", user_id="test_user") - - -def test_delete_all_nonexistent_memories(mock_mos): - """Test deleting all memories from non-existent MemCube.""" - mock_mos.delete_all.side_effect = ValueError("MemCube not found") - response = client.delete("/memories/nonexistent_cube") - assert response.status_code == 400 - assert "MemCube not found" in response.json()["message"] - - -def test_chat(mock_mos): - """Test chat endpoint.""" - response = client.post("/chat", json=MOCK_CHAT_REQUEST) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Chat response generated", - "data": "test response", - } - mock_mos.chat.assert_called_once_with(query="test chat query", user_id="test_user") - - -def test_chat_without_user_id(mock_mos): - """Test chat endpoint without user_id.""" - chat_request = {"query": "test chat query"} - response = client.post("/chat", json=chat_request) - assert response.status_code == 200 - assert response.json() == { - "code": 200, - "message": "Chat response generated", - "data": "test response", - } - mock_mos.chat.assert_called_once_with(query="test chat query", user_id=None) - - -def test_home_redirect(): - """Test home endpoint redirects to docs.""" - response = client.get("/", follow_redirects=False) - assert response.status_code == 307 - assert response.headers["location"] == "/docs" diff --git a/tests/test_cli.py b/tests/test_cli.py index 9750af121..a1e423e4f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,13 +16,13 @@ class TestExportOpenAPI: """Test the export_openapi function.""" - @patch("memos.api.start_api.app") + @patch("memos.cli.get_openapi_app") @patch("builtins.open", new_callable=mock_open) @patch("os.makedirs") def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app): """Test successful OpenAPI export.""" mock_openapi_data = {"openapi": "3.0.0", "info": {"title": "Test API"}} - mock_app.openapi.return_value = mock_openapi_data + mock_app.return_value.openapi.return_value = mock_openapi_data result = export_openapi("/test/path/openapi.json") @@ -30,11 +30,11 @@ def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app): mock_makedirs.assert_called_once_with("/test/path", exist_ok=True) mock_file.assert_called_once_with("/test/path/openapi.json", "w") - @patch("memos.api.start_api.app") + @patch("memos.cli.get_openapi_app") @patch("builtins.open", side_effect=OSError("Permission denied")) def test_export_openapi_error(self, mock_file, mock_app): """Test OpenAPI export when file writing fails.""" - mock_app.openapi.return_value = {"test": "data"} + mock_app.return_value.openapi.return_value = {"test": "data"} with pytest.raises(IOError): export_openapi("/invalid/path/openapi.json")