diff --git a/README.md b/README.md index a7b05d683..1c7d5bd93 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ - [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) — intelligent memory retrieval instead of loading full chat history - [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) — multi-instance agents share memory via same user_id, automatic context handoff -Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) +Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/) Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin) ### 🧠 Local Plugin — 100% On-Device Memory @@ -84,7 +84,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - **Hybrid search + task & skill evolution** — FTS5 + vector search, auto task summarization, reusable skills that self-upgrade - **Multi-agent collaboration + Memory Viewer** — memory isolation, skill sharing, full web dashboard with 7 management pages - 🌐 [Homepage](https://memos-claw.openmem.net) · + 🌐 [Homepage](https://memos-claw.openmem.net) · 📖 [Documentation](https://memos-claw.openmem.net/docs/index.html) · 📦 [NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin) ## 📌 MemOS: Memory Operating System for AI Agents @@ -104,10 +104,10 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem ### News -- **2026-03-08** · 🦞 **MemOS OpenClaw Plugin — Cloud & Local** +- **2026-03-08** · 🦞 **MemOS OpenClaw Plugin — Cloud & Local** Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard. -- **2025-12-24** · 🎉 **MemOS v2.0: Stardust (星尘) Release** +- **2025-12-24** · 🎉 **MemOS v2.0: Stardust (星尘) Release** Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
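The Local Plugin entry above credits its hybrid search to SQLite FTS5 plus vector similarity. Below is a minimal sketch of the FTS5 keyword half using only the Python standard library; the table name and column are illustrative, not the plugin's actual schema, and the host SQLite must be compiled with FTS5 (CPython's bundled build normally is):

```python
import sqlite3

# Illustrative only: a tiny FTS5 table standing in for the plugin's store.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE memories USING fts5(content)")
conn.executemany(
    "INSERT INTO memories(content) VALUES (?)",
    [("user prefers dark mode",), ("agent summarized the refactoring task",)],
)
# FTS5's built-in bm25 ranking is exposed through the hidden `rank` column.
for (content,) in conn.execute(
    "SELECT content FROM memories WHERE memories MATCH ? ORDER BY rank",
    ("task",),
):
    print(content)  # -> agent summarized the refactoring task
```

A production hybrid search would then merge such keyword hits with vector-similarity scores, which is the combination the plugin advertises.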
✨ New Features @@ -155,7 +155,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release** - First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch. + First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
✨ New Features @@ -176,7 +176,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem **Plaintext Memory** - Integrated internet search with Bocha. - - Added support for Nebula database. + - Expanded graph database support. - Added contextual understanding for the tree-structured plaintext memory search interface.
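The hunks that follow (docker/.env.example and src/memos/api/handlers/config_builders.py) consolidate graph-store selection to four backends and move the default from the removed "nebular" to "neo4j-community". A minimal sketch of that resolution order, with placeholder dicts standing in for the real APIConfig.get_*_config() helpers:

```python
import os
from typing import Any

# Placeholder configs; the real code calls APIConfig.get_*_config() helpers.
BACKEND_MAP: dict[str, dict[str, Any]] = {
    "neo4j-community": {"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687")},
    "neo4j": {"uri": os.getenv("NEO4J_URI", "bolt://localhost:7687")},
    "polardb": {},   # APIConfig.get_polardb_config() in the real code
    "postgres": {},  # APIConfig.get_postgres_config() in the real code
}

def resolve_backend() -> str:
    # GRAPH_DB_BACKEND wins; the legacy NEO4J_BACKEND is the fallback, and
    # the default is now "neo4j-community" instead of the removed "nebular".
    backend = os.getenv(
        "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
    ).lower()
    if backend not in BACKEND_MAP:
        raise ValueError(f"Invalid graph DB backend: {backend}")
    return backend
```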
@@ -188,7 +188,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem - Fixed the concat_cache method. **Plaintext Memory** - - Fixed Nebula search-related issues. + - Fixed graph search-related issues. diff --git a/docker/.env.example b/docker/.env.example index 3674cd69b..08f9cf1db 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -127,7 +127,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false ## Graph / vector stores # Neo4j database selection mode -NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb +NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres # Neo4j database url NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j* # Neo4j database user diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py index 6dbe202c2..d14b6d687 100644 --- a/examples/mem_agent/deepsearch_example.py +++ b/examples/mem_agent/deepsearch_example.py @@ -47,7 +47,7 @@ def build_minimal_components(): # Build component configurations using APIConfig methods (like config_builders.py) - # Graph DB configuration - using APIConfig.get_nebular_config() + # Graph DB configuration - using APIConfig graph DB helpers graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower() graph_db_backend_map = { "polardb": APIConfig.get_polardb_config(), diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 87f1efd8e..2e8bae57c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -741,21 +741,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]: "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), } - @staticmethod - def get_nebular_config(user_id: str | None = None) -> dict[str, Any]: - """Get Nebular configuration.""" - return { - "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')), - "user": os.getenv("NEBULAR_USER", "root"), - "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"), - "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"), - "user_name": f"memos{user_id.replace('-', '')}", - "use_multi_db": False, - "auto_create": True, - "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)), - } - - @staticmethod def get_milvus_config(): return { "collection_name": [ @@ -1103,7 +1088,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene neo4j_community_config = APIConfig.get_neo4j_community_config(user_id) neo4j_config = APIConfig.get_neo4j_config(user_id) - nebular_config = APIConfig.get_nebular_config(user_id) polardb_config = APIConfig.get_polardb_config(user_id) internet_config = ( APIConfig.get_internet_config() @@ -1114,7 +1098,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1144,9 +1127,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene "reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower() == "true", "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), },
"search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1169,7 +1152,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") default_mem_cube = GeneralMemCube(default_cube_config) return default_config, default_mem_cube @@ -1188,13 +1171,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": openai_config = APIConfig.get_openai_config() neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default") neo4j_config = APIConfig.get_neo4j_config(user_id="default") - nebular_config = APIConfig.get_nebular_config(user_id="default") polardb_config = APIConfig.get_polardb_config(user_id="default") postgres_config = APIConfig.get_postgres_config(user_id="default") graph_db_backend_map = { "neo4j-community": neo4j_community_config, "neo4j": neo4j_config, - "nebular": nebular_config, "polardb": polardb_config, "postgres": postgres_config, } @@ -1227,9 +1208,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": == "true", "internet_retriever": internet_config, "memory_size": { - "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)), - "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)), - "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)), + "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)), + "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)), + "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)), }, "search_strategy": { "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"), @@ -1253,4 +1234,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None": } ) else: - raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}") + raise ValueError(f"Invalid graph DB backend: {graph_db_backend}") diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index 5655bf1e5..d29429fc9 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: graph_db_backend_map = { "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), "polardb": APIConfig.get_polardb_config(user_id=user_id), "postgres": APIConfig.get_postgres_config(user_id=user_id), } # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars - graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower() + graph_db_backend = os.getenv( + "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community") + ).lower() return GraphDBConfigFactory.model_validate( { "backend": graph_db_backend, diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py index 5900d2357..98de09812 100644 --- a/src/memos/configs/graph_db.py +++ b/src/memos/configs/graph_db.py @@ -103,57 +103,6 @@ def validate_community(self): return self -class NebulaGraphDBConfig(BaseGraphDBConfig): - """ - NebulaGraph-specific configuration. - - Key concepts: - - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space. - - `user_name`: Used for logical tenant isolation if needed. - - `auto_create`: Whether to automatically create the target space if it does not exist. 
- - Example: - --- - hosts = ["127.0.0.1:9669"] - user = "root" - password = "nebula" - space = "shared_graph" - user_name = "alice" - """ - - space: str = Field( - ..., description="The name of the target NebulaGraph space (like a database)" - ) - user_name: str | None = Field( - default=None, - description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)", - ) - auto_create: bool = Field( - default=False, - description="Whether to auto-create the space if it does not exist", - ) - use_multi_db: bool = Field( - default=True, - description=( - "If True: use Neo4j's multi-database feature for physical isolation; " - "each user typically gets a separate database. " - "If False: use a single shared database with logical isolation by user_name." - ), - ) - max_client: int = Field( - default=1000, - description=("max_client"), - ) - embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding") - - @model_validator(mode="after") - def validate_config(self): - """Validate config.""" - if not self.space: - raise ValueError("`space` must be provided") - return self - - class PolarDBGraphDBConfig(BaseConfig): """ PolarDB-specific configuration. @@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDBConfig, "neo4j-community": Neo4jCommunityGraphDBConfig, - "nebular": NebulaGraphDBConfig, "polardb": PolarDBGraphDBConfig, "postgres": PostgresGraphDBConfig, } diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py index c207e3190..93b5971ec 100644 --- a/src/memos/graph_dbs/factory.py +++ b/src/memos/graph_dbs/factory.py @@ -2,7 +2,6 @@ from memos.configs.graph_db import GraphDBConfigFactory from memos.graph_dbs.base import BaseGraphDB -from memos.graph_dbs.nebular import NebulaGraphDB from memos.graph_dbs.neo4j import Neo4jGraphDB from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB from memos.graph_dbs.polardb import PolarDBGraphDB @@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB): backend_to_class: ClassVar[dict[str, Any]] = { "neo4j": Neo4jGraphDB, "neo4j-community": Neo4jCommunityGraphDB, - "nebular": NebulaGraphDB, "polardb": PolarDBGraphDB, "postgres": PostgresGraphDB, } diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py deleted file mode 100644 index 428d6d09e..000000000 --- a/src/memos/graph_dbs/nebular.py +++ /dev/null @@ -1,1794 +0,0 @@ -import json -import traceback - -from contextlib import suppress -from datetime import datetime -from threading import Lock -from typing import TYPE_CHECKING, Any, ClassVar, Literal - -import numpy as np - -from memos.configs.graph_db import NebulaGraphDBConfig -from memos.dependency import require_python_package -from memos.graph_dbs.base import BaseGraphDB -from memos.log import get_logger -from memos.utils import timed - - -if TYPE_CHECKING: - from nebulagraph_python import ( - NebulaClient, - ) - - -logger = get_logger(__name__) - - -_TRANSIENT_ERR_KEYS = ( - "Session not found", - "Connection not established", - "timeout", - "deadline exceeded", - "Broken pipe", - "EOFError", - "socket closed", - "connection reset", - "connection refused", -) - - -@timed -def _normalize(vec: list[float]) -> list[float]: - v = np.asarray(vec, dtype=np.float32) - norm = np.linalg.norm(v) - return (v / (norm if norm else 1.0)).tolist() - - -@timed -def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]: - node_id = item["id"] - memory = item["memory"] - 
metadata = item.get("metadata", {}) - return node_id, memory, metadata - - -@timed -def _escape_str(value: str) -> str: - out = [] - for ch in value: - code = ord(ch) - if ch == "\\": - out.append("\\\\") - elif ch == '"': - out.append('\\"') - elif ch == "\n": - out.append("\\n") - elif ch == "\r": - out.append("\\r") - elif ch == "\t": - out.append("\\t") - elif ch == "\b": - out.append("\\b") - elif ch == "\f": - out.append("\\f") - elif code < 0x20 or code in (0x2028, 0x2029): - out.append(f"\\u{code:04x}") - else: - out.append(ch) - return "".join(out) - - -@timed -def _format_datetime(value: str | datetime) -> str: - """Ensure datetime is in ISO 8601 format string.""" - if isinstance(value, datetime): - return value.isoformat() - return str(value) - - -@timed -def _normalize_datetime(val): - """ - Normalize datetime to ISO 8601 UTC string with +00:00. - - If val is datetime object -> keep isoformat() (Neo4j) - - If val is string without timezone -> append +00:00 (Nebula) - - Otherwise just str() - """ - if hasattr(val, "isoformat"): - return val.isoformat() - if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")): - return val + "+08:00" - return str(val) - - -class NebulaGraphDB(BaseGraphDB): - """ - NebulaGraph-based implementation of a graph memory store. - """ - - # ====== shared pool cache & refcount ====== - # These are process-local; in a multi-process model each process will - # have its own cache. - _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {} - _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {} - _CLIENT_LOCK: ClassVar[Lock] = Lock() - _CLIENT_INIT_DONE: ClassVar[set[str]] = set() - - @staticmethod - def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]: - hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None) - if isinstance(hosts, str): - return [hosts] - return list(hosts or []) - - @staticmethod - def _make_client_key(cfg: NebulaGraphDBConfig) -> str: - hosts = NebulaGraphDB._get_hosts_from_cfg(cfg) - return "|".join( - [ - "nebula-sync", - ",".join(hosts), - str(getattr(cfg, "user", "")), - str(getattr(cfg, "space", "")), - ] - ) - - @classmethod - def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB": - tmp = object.__new__(NebulaGraphDB) - tmp.config = cfg - tmp.db_name = cfg.space - tmp.user_name = None - tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072) - tmp.default_memory_dimension = 3072 - tmp.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - tmp.base_fields = set(tmp.common_fields) - {"usage"} - tmp.heavy_fields = {"usage"} - tmp.dim_field = ( - f"embedding_{tmp.embedding_dimension}" - if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) - else "embedding" - ) - tmp.system_db_name = cfg.space - tmp._client = client - tmp._owns_client = False - return tmp - - @classmethod - def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]: - from nebulagraph_python import ( - ConnectionConfig, - NebulaClient, - SessionConfig, - SessionPoolConfig, - ) - - key = cls._make_client_key(cfg) - with cls._CLIENT_LOCK: - client = cls._CLIENT_CACHE.get(key) - if client is None: - # Connection setting - - tmp_client = NebulaClient( - hosts=cfg.uri, - username=cfg.user, - password=cfg.password, - 
session_config=SessionConfig(graph=None), - session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), - ) - try: - cls._ensure_space_exists(tmp_client, cfg) - finally: - tmp_client.close() - - conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) - if conn_conf is None: - conn_conf = ConnectionConfig.from_defults( - cls._get_hosts_from_cfg(cfg), - getattr(cfg, "ssl_param", None), - ) - - sess_conf = SessionConfig(graph=getattr(cfg, "space", None)) - pool_conf = SessionPoolConfig( - size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000 - ) - - client = NebulaClient( - hosts=conn_conf.hosts, - username=cfg.user, - password=cfg.password, - conn_config=conn_conf, - session_config=sess_conf, - session_pool_config=pool_conf, - ) - cls._CLIENT_CACHE[key] = client - cls._CLIENT_REFCOUNT[key] = 0 - logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}") - - cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1 - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - try: - pass - finally: - pass - - if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE: - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_INIT_DONE: - admin = cls._bootstrap_admin(cfg, client) - try: - admin._ensure_database_exists() - admin._create_basic_property_indexes() - admin._create_vector_index( - dimensions=int( - admin.embedding_dimension or admin.default_memory_dimension - ), - ) - cls._CLIENT_INIT_DONE.add(key) - logger.info("[NebulaGraphDBSync] One-time init done") - except Exception: - logger.exception("[NebulaGraphDBSync] One-time init failed") - - return key, client - - def _refresh_client(self): - """ - refresh NebulaClient: - """ - old_key = getattr(self, "_client_key", None) - if not old_key: - return - - cls = self.__class__ - with cls._CLIENT_LOCK: - try: - if old_key in cls._CLIENT_CACHE: - try: - cls._CLIENT_CACHE[old_key].close() - except Exception as e: - logger.warning(f"[refresh_client] close old client error: {e}") - finally: - cls._CLIENT_CACHE.pop(old_key, None) - finally: - cls._CLIENT_REFCOUNT[old_key] = 0 - - new_key, new_client = cls._get_or_create_shared_client(self.config) - self._client_key = new_key - self._client = new_client - logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}") - - @classmethod - def _release_shared_client(cls, key: str): - with cls._CLIENT_LOCK: - if key not in cls._CLIENT_CACHE: - return - cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1) - if cls._CLIENT_REFCOUNT[key] == 0: - try: - cls._CLIENT_CACHE[key].close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}") - finally: - cls._CLIENT_CACHE.pop(key, None) - cls._CLIENT_REFCOUNT.pop(key, None) - logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}") - - @classmethod - def close_all_shared_clients(cls): - with cls._CLIENT_LOCK: - for key, client in list(cls._CLIENT_CACHE.items()): - try: - client.close() - except Exception as e: - logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}") - finally: - logger.info(f"[NebulaGraphDBSync] Closed client key={key}") - cls._CLIENT_CACHE.clear() - cls._CLIENT_REFCOUNT.clear() - - @require_python_package( - import_name="nebulagraph_python", - install_command="pip install nebulagraph-python>=5.1.1", - install_link=".....", - ) - def __init__(self, config: NebulaGraphDBConfig): - """ - NebulaGraph DB client initialization. 
- - Required config attributes: - - hosts: list[str] like ["host1:port", "host2:port"] - - user: str - - password: str - - db_name: str (optional for basic commands) - - Example config: - { - "hosts": ["xxx.xx.xx.xxx:xxxx"], - "user": "root", - "password": "nebula", - "space": "test" - } - """ - - assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" - self.config = config - self.db_name = config.space - self.user_name = config.user_name - self.embedding_dimension = config.embedding_dimension - self.default_memory_dimension = 3072 - self.common_fields = { - "id", - "memory", - "user_name", - "user_id", - "session_id", - "status", - "key", - "confidence", - "tags", - "created_at", - "updated_at", - "memory_type", - "sources", - "source", - "node_type", - "visibility", - "usage", - "background", - } - self.base_fields = set(self.common_fields) - {"usage"} - self.heavy_fields = {"usage"} - self.dim_field = ( - f"embedding_{self.embedding_dimension}" - if (str(self.embedding_dimension) != str(self.default_memory_dimension)) - else "embedding" - ) - self.system_db_name = config.space - - # ---- NEW: pool acquisition strategy - # Get or create a shared pool from the class-level cache - self._client_key, self._client = self._get_or_create_shared_client(config) - self._owns_client = True - - logger.info("Connected to NebulaGraph successfully.") - - @timed - def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True): - def _wrap_use_db(q: str) -> str: - if auto_set_db and self.db_name: - return f"USE `{self.db_name}`\n{q}" - return q - - try: - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - - except Exception as e: - emsg = str(e) - if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS): - logger.warning(f"[execute_query] {e!s} โ†’ refreshing session pool and retry once...") - try: - self._refresh_client() - return self._client.execute(_wrap_use_db(gql), timeout=timeout) - except Exception: - logger.exception("[execute_query] retry after refresh failed") - raise - raise - - @timed - def close(self): - """ - Close the connection resource if this instance owns it. - - - If pool was injected (`shared_pool`), do nothing. - - If pool was acquired via shared cache, decrement refcount and close - when the last owner releases it. - """ - if not self._owns_client: - logger.debug("[NebulaGraphDBSync] close() skipped (injected client).") - return - if self._client_key: - self._release_shared_client(self._client_key) - self._client_key = None - self._client = None - - # NOTE: __del__ is best-effort; do not rely on GC order. - def __del__(self): - with suppress(Exception): - self.close() - - @timed - def create_index( - self, - label: str = "Memory", - vector_property: str = "embedding", - dimensions: int = 3072, - index_name: str = "memory_vector_index", - ) -> None: - # Create vector index - self._create_vector_index(label, vector_property, dimensions, index_name) - # Create indexes - self._create_basic_property_indexes() - - @timed - def remove_oldest_memory( - self, memory_type: str, keep_latest: int, user_name: str | None = None - ) -> None: - """ - Remove all WorkingMemory nodes except the latest `keep_latest` entries. - - Args: - memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). - keep_latest (int): Number of latest WorkingMemory entries to keep. - user_name(str): optional user_name. 
- """ - try: - user_name = user_name if user_name else self.config.user_name - optional_condition = f"AND n.user_name = '{user_name}'" - count = self.count_nodes(memory_type, user_name) - if count > keep_latest: - delete_query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {int(keep_latest)} - DETACH DELETE n - """ - self.execute_query(delete_query) - except Exception as e: - logger.warning(f"Delete old mem error: {e}") - - @timed - def add_node( - self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None - ) -> None: - """ - Insert or update a Memory node in NebulaGraph. - """ - metadata["user_name"] = user_name if user_name else self.config.user_name - now = datetime.utcnow() - metadata = metadata.copy() - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - metadata["node_type"] = metadata.pop("type") - metadata["id"] = id - metadata["memory"] = memory - - if "embedding" in metadata and isinstance(metadata["embedding"], list): - assert len(metadata["embedding"]) == self.embedding_dimension, ( - f"input embedding dimension must equal to {self.embedding_dimension}" - ) - embedding = metadata.pop("embedding") - metadata[self.dim_field] = _normalize(embedding) - - metadata = self._metadata_filter(metadata) - properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()) - gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - - try: - self.execute_query(gql) - logger.info("insert success") - except Exception as e: - logger.error( - f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}" - ) - - @timed - def node_not_exist(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"' - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {filter_clause} - RETURN n.id AS id - LIMIT 1 - """ - - try: - result = self.execute_query(query) - return result.size == 0 - except Exception as e: - logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True) - raise - - @timed - def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None: - """ - Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present. - """ - user_name = user_name if user_name else self.config.user_name - fields = fields.copy() - set_clauses = [] - for k, v in fields.items(): - set_clauses.append(f"n.{k} = {self._format_value(v, k)}") - - set_clause_str = ",\n ".join(set_clauses) - - query = f""" - MATCH (n@Memory {{id: "{id}"}}) - """ - query += f'WHERE n.user_name = "{user_name}"' - - query += f"\nSET {set_clause_str}" - self.execute_query(query) - - @timed - def delete_node(self, id: str, user_name: str | None = None) -> None: - """ - Delete a node from the graph. - Args: - id: Node identifier to delete. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)} - DETACH DELETE n - """ - self.execute_query(query) - - @timed - def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None): - """ - Create an edge from source node to target node. 
- Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type (e.g., 'RELATE_TO', 'PARENT'). - user_name (str, optional): User name for filtering in non-multi-db mode - """ - if not source_id or not target_id: - raise ValueError("[add_edge] source_id and target_id must be provided") - user_name = user_name if user_name else self.config.user_name - props = "" - props = f'{{user_name: "{user_name}"}}' - insert_stmt = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT (a) -[e@{type} {props}]-> (b) - ''' - try: - self.execute_query(insert_stmt) - except Exception as e: - logger.error(f"Failed to insert edge: {e}", exc_info=True) - - @timed - def delete_edge( - self, source_id: str, target_id: str, type: str, user_name: str | None = None - ) -> None: - """ - Delete a specific edge between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type to remove. - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (a@Memory) -[r@{type}]-> (b@Memory) - WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} - """ - - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" - query += "\nDELETE r" - self.execute_query(query) - - @timed - def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{memory_type}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN COUNT(n) AS count" - - try: - result = self.execute_query(query) - return result.one_or_none()["count"].value - except Exception as e: - logger.error(f"[get_memory_count] Failed: {e}") - return -1 - - @timed - def count_nodes(self, scope: str, user_name: str | None = None) -> int: - user_name = user_name if user_name else self.config.user_name - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ - query += f"\nAND n.user_name = '{user_name}'" - query += "\nRETURN count(n) AS count" - - result = self.execute_query(query) - return result.one_or_none()["count"].value - - @timed - def edge_exists( - self, - source_id: str, - target_id: str, - type: str = "ANY", - direction: str = "OUTGOING", - user_name: str | None = None, - ) -> bool: - """ - Check if an edge exists between two nodes. - Args: - source_id: ID of the source node. - target_id: ID of the target node. - type: Relationship type. Use "ANY" to match any relationship type. - direction: Direction of the edge. - Use "OUTGOING" (default), "INCOMING", or "ANY". - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - True if the edge exists, otherwise False. 
- """ - # Prepare the relationship pattern - user_name = user_name if user_name else self.config.user_name - rel = "r" if type == "ANY" else f"r@{type}" - - # Prepare the match pattern with direction - if direction == "OUTGOING": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})" - elif direction == "INCOMING": - pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})" - elif direction == "ANY": - pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})" - else: - raise ValueError( - f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." - ) - query = f"MATCH {pattern}" - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - query += "\nRETURN r" - - # Run the Cypher query - result = self.execute_query(query) - record = result.one_or_none() - if record is None: - return False - return record.values() is not None - - @timed - # Graph Query & Reasoning - def get_node( - self, id: str, include_embedding: bool = False, user_name: str | None = None - ) -> dict[str, Any] | None: - """ - Retrieve a Memory node by its unique ID. - - Args: - id (str): Node ID (Memory.id) - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - dict: Node properties as key-value pairs, or None if not found. - """ - filter_clause = f'n.id = "{id}"' - return_fields = self._build_return_fields(include_embedding) - gql = f""" - MATCH (n@Memory) - WHERE {filter_clause} - RETURN {return_fields} - """ - - try: - result = self.execute_query(gql) - for row in result: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - return node - - except Exception as e: - logger.error( - f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}" - ) - return None - - @timed - def get_nodes( - self, - ids: list[str], - include_embedding: bool = False, - user_name: str | None = None, - **kwargs, - ) -> list[dict[str, Any]]: - """ - Retrieve the metadata and memory of a list of nodes. - Args: - ids: List of Node identifier. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'. - - Notes: - - Assumes all provided IDs are valid and exist. - - Returns empty list if input is empty. - """ - if not ids: - return [] - # Safe formatting of the ID list - id_list = ",".join(f'"{_id}"' for _id in ids) - - return_fields = self._build_return_fields(include_embedding) - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] - RETURN {return_fields} - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error( - f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}" - ) - return nodes - - @timed - def get_edges( - self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None - ) -> list[dict[str, str]]: - """ - Get edges connected to a node, with optional type and direction filter. - - Args: - id: Node ID to retrieve edges for. - type: Relationship type to match, or 'ANY' to match all. - direction: 'OUTGOING', 'INCOMING', or 'ANY'. 
- user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of edges: - [ - {"from": "source_id", "to": "target_id", "type": "RELATE"}, - ... - ] - """ - # Build relationship type filter - rel_type = "" if type == "ANY" else f"@{type}" - user_name = user_name if user_name else self.config.user_name - # Build Cypher pattern based on direction - if direction == "OUTGOING": - pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "INCOMING": - pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}'" - elif direction == "ANY": - pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)" - where_clause = f"a.id = '{id}' OR b.id = '{id}'" - else: - raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - - where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'" - - query = f""" - MATCH {pattern} - WHERE {where_clause} - RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type - """ - - result = self.execute_query(query) - edges = [] - for record in result: - edges.append( - { - "from": record["from_id"].value, - "to": record["to_id"].value, - "type": record["edge_type"].value, - } - ) - return edges - - @timed - def get_neighbors_by_tag( - self, - tags: list[str], - exclude_ids: list[str], - top_k: int = 5, - min_overlap: int = 1, - include_embedding: bool = False, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Find top-K neighbor nodes with maximum tag overlap. - - Args: - tags: The list of tags to match. - exclude_ids: Node IDs to exclude (e.g., local cluster). - top_k: Max number of neighbors to return. - min_overlap: Minimum number of overlapping tags required. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - List of dicts with node details and overlap count. 
- """ - if not tags: - return [] - user_name = user_name if user_name else self.config.user_name - where_clauses = [ - 'n.status = "activated"', - 'NOT (n.node_type = "reasoning")', - 'NOT (n.memory_type = "WorkingMemory")', - ] - if exclude_ids: - where_clauses.append(f"NOT (n.id IN {exclude_ids})") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_clause = " AND ".join(where_clauses) - tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" - - return_fields = self._build_return_fields(include_embedding) - query = f""" - LET tag_list = {tag_list_literal} - - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - RETURN {return_fields}, - size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count - ORDER BY overlap_count DESC - LIMIT {top_k} - """ - - result = self.execute_query(query) - neighbors: list[dict[str, Any]] = [] - for r in result: - props = {k: v.value for k, v in r.items() if k != "overlap_count"} - parsed = self._parse_node(props) - parsed["overlap_count"] = r["overlap_count"].value - neighbors.append(parsed) - - neighbors.sort(key=lambda x: x["overlap_count"], reverse=True) - neighbors = neighbors[:top_k] - result = [] - for neighbor in neighbors[:top_k]: - neighbor.pop("overlap_count") - result.append(neighbor) - return result - - @timed - def get_children_with_embeddings( - self, id: str, user_name: str | None = None - ) -> list[dict[str, Any]]: - user_name = user_name if user_name else self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" - - query = f""" - MATCH (p@Memory)-[@PARENT]->(c@Memory) - WHERE p.id = "{id}" {where_user} - RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory - """ - result = self.execute_query(query) - children = [] - for row in result: - eid = row["id"].value # STRING - emb_v = row[self.dim_field].value # NVector - emb = list(emb_v.values) if emb_v else [] - mem = row["memory"].value # STRING - - children.append({"id": eid, "embedding": emb, "memory": mem}) - return children - - @timed - def get_subgraph( - self, - center_id: str, - depth: int = 2, - center_status: str = "activated", - user_name: str | None = None, - ) -> dict[str, Any]: - """ - Retrieve a local subgraph centered at a given node. - Args: - center_id: The ID of the center node. - depth: The hop distance for neighbors. - center_status: Required status for center node. - user_name (str, optional): User name for filtering in non-multi-db mode - Returns: - { - "core_node": {...}, - "neighbors": [...], - "edges": [...] 
- } - """ - if not 1 <= depth <= 5: - raise ValueError("depth must be 1-5") - - user_name = user_name if user_name else self.config.user_name - - gql = f""" - MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE center.id = '{center_id}' - AND center.status = '{center_status}' - AND center.user_name = '{user_name}' - OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory) - WHERE neighbor.user_name = '{user_name}' - RETURN center, - collect(DISTINCT neighbor) AS neighbors, - collect(EDGES(p)) AS edge_chains - """ - - result = self.execute_query(gql).one_or_none() - if not result or result.size == 0: - return {"core_node": None, "neighbors": [], "edges": []} - - core_node_props = result["center"].as_node().get_properties() - core_node = self._parse_node(core_node_props) - neighbors = [] - vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]} - for n in result["neighbors"].value: - n_node = n.as_node() - n_props = n_node.get_properties() - node_parsed = self._parse_node(n_props) - neighbors.append(node_parsed) - vid_to_id_map[n_node.node_id] = node_parsed["id"] - - edges = [] - for chain_group in result["edge_chains"].value: - for edge_wr in chain_group.value: - edge = edge_wr.value - edges.append( - { - "type": edge.get_type(), - "source": vid_to_id_map.get(edge.get_src_id()), - "target": vid_to_id_map.get(edge.get_dst_id()), - } - ) - - return {"core_node": core_node, "neighbors": neighbors, "edges": edges} - - @timed - # Search / recall operations - def search_by_embedding( - self, - vector: list[float], - top_k: int = 5, - scope: str | None = None, - status: str | None = None, - threshold: float | None = None, - search_filter: dict | None = None, - user_name: str | None = None, - **kwargs, - ) -> list[dict]: - """ - Retrieve node IDs based on vector similarity. - - Args: - vector (list[float]): The embedding vector representing query semantics. - top_k (int): Number of top similar nodes to retrieve. - scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory'). - status (str, optional): Node status filter (e.g., 'active', 'archived'). - If provided, restricts results to nodes with matching status. - threshold (float, optional): Minimum similarity score threshold (0 ~ 1). - search_filter (dict, optional): Additional metadata filters for search results. - Keys should match node properties, values are the expected values. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: A list of dicts with 'id' and 'score', ordered by similarity. - - Notes: - - This method uses Neo4j native vector indexing to search for similar nodes. - - If scope is provided, it restricts results to nodes with matching memory_type. - - If 'status' is provided, only nodes with the matching status will be returned. - - If threshold is provided, only results with score >= threshold will be returned. - - If search_filter is provided, additional WHERE clauses will be added for metadata filtering. - - Typical use case: restrict to 'status = activated' to avoid - matching archived or merged nodes. 
- """ - user_name = user_name if user_name else self.config.user_name - vector = _normalize(vector) - dim = len(vector) - vector_str = ",".join(f"{float(x)}" for x in vector) - gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - where_clauses = [f"n.{self.dim_field} IS NOT NULL"] - if scope: - where_clauses.append(f'n.memory_type = "{scope}"') - if status: - where_clauses.append(f'n.status = "{status}"') - where_clauses.append(f'n.user_name = "{user_name}"') - - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") - - where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" - - gql = f""" - let a = {gql_vector} - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - ORDER BY inner_product(n.{self.dim_field}, a) DESC - LIMIT {top_k} - RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" - try: - result = self.execute_query(gql) - except Exception as e: - logger.error(f"[search_by_embedding] Query failed: {e}") - return [] - - try: - output = [] - for row in result: - values = row.values() - id_val = values[0].as_string() - score_val = values[1].as_double() - score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score - if threshold is None or score_val >= threshold: - output.append({"id": id_val, "score": score_val}) - return output - except Exception as e: - logger.error(f"[search_by_embedding] Result parse failed: {e}") - return [] - - @timed - def get_by_metadata( - self, filters: list[dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """ - 1. ADD logic: "AND" vs "OR"(support logic combination); - 2. Support nested conditional expressions; - - Retrieve node IDs that match given metadata filters. - Supports exact match. - - Args: - filters: List of filter dicts like: - [ - {"field": "key", "op": "in", "value": ["A", "B"]}, - {"field": "confidence", "op": ">=", "value": 80}, - {"field": "tags", "op": "contains", "value": "AI"}, - ... - ] - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[str]: Node IDs whose metadata match the filter conditions. (AND logic). - - Notes: - - Supports structured querying such as tag/category/importance/time filtering. - - Can be used for faceted recall or prefiltering before embedding rerank. 
- """ - where_clauses = [] - user_name = user_name if user_name else self.config.user_name - for _i, f in enumerate(filters): - field = f["field"] - op = f.get("op", "=") - value = f["value"] - - escaped_value = self._format_value(value) - - # Build WHERE clause - if op == "=": - where_clauses.append(f"n.{field} = {escaped_value}") - elif op == "in": - where_clauses.append(f"n.{field} IN {escaped_value}") - elif op == "contains": - where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0") - elif op == "starts_with": - where_clauses.append(f"n.{field} STARTS WITH {escaped_value}") - elif op == "ends_with": - where_clauses.append(f"n.{field} ENDS WITH {escaped_value}") - elif op in [">", ">=", "<", "<="]: - where_clauses.append(f"n.{field} {op} {escaped_value}") - else: - raise ValueError(f"Unsupported operator: {op}") - - where_clauses.append(f'n.user_name = "{user_name}"') - - where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" - ids = [] - try: - result = self.execute_query(gql) - ids = [record["id"].value for record in result] - except Exception as e: - logger.error(f"Failed to get metadata: {e}, gql is {gql}") - return ids - - @timed - def get_grouped_counts( - self, - group_fields: list[str], - where_clause: str = "", - params: dict[str, Any] | None = None, - user_name: str | None = None, - ) -> list[dict[str, Any]]: - """ - Count nodes grouped by any fields. - - Args: - group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"] - where_clause (str, optional): Extra WHERE condition. E.g., - "WHERE n.status = 'activated'" - params (dict, optional): Parameters for WHERE clause. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...] - """ - if not group_fields: - raise ValueError("group_fields cannot be empty") - user_name = user_name if user_name else self.config.user_name - # GQL-specific modifications - user_clause = f"n.user_name = '{user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" - else: - where_clause = f"WHERE {user_clause}" - - # Inline parameters if provided - if params: - for key, value in params.items(): - # Handle different value types appropriately - if isinstance(value, str): - value = f"'{value}'" - where_clause = where_clause.replace(f"${key}", str(value)) - - return_fields = [] - group_by_fields = [] - - for field in group_fields: - alias = field.replace(".", "_") - return_fields.append(f"n.{field} AS {alias}") - group_by_fields.append(alias) - # Full GQL query construction - gql = f""" - MATCH (n /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {", ".join(return_fields)}, COUNT(n) AS count - """ - result = self.execute_query(gql) # Pure GQL string execution - - output = [] - for record in result: - group_values = {} - for i, field in enumerate(group_fields): - value = record.values()[i].as_string() - group_values[field] = value - count_value = record["count"].value - output.append({**group_values, "count": count_value}) - - return output - - @timed - def clear(self, user_name: str | None = None) -> None: - """ - Clear the entire graph if the target database exists. 
- - Args: - user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - try: - query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n" - self.execute_query(query) - logger.info("Cleared all nodes from database.") - - except Exception as e: - logger.error(f"[ERROR] Failed to clear database: {e}") - - @timed - def export_graph( - self, include_embedding: bool = False, user_name: str | None = None, **kwargs - ) -> dict[str, Any]: - """ - Export all graph nodes and edges in a structured form. - Args: - include_embedding (bool): Whether to include the large embedding field. - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - { - "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ], - "edges": [ { "source": ..., "target": ..., "type": ... }, ... ] - } - """ - user_name = user_name if user_name else self.config.user_name - node_query = "MATCH (n@Memory)" - edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - node_query += f' WHERE n.user_name = "{user_name}"' - edge_query += f' WHERE r.user_name = "{user_name}"' - - try: - if include_embedding: - return_fields = "n" - else: - return_fields = ",".join( - [ - "n.id AS id", - "n.memory AS memory", - "n.user_name AS user_name", - "n.user_id AS user_id", - "n.session_id AS session_id", - "n.status AS status", - "n.key AS key", - "n.confidence AS confidence", - "n.tags AS tags", - "n.created_at AS created_at", - "n.updated_at AS updated_at", - "n.memory_type AS memory_type", - "n.sources AS sources", - "n.source AS source", - "n.node_type AS node_type", - "n.visibility AS visibility", - "n.usage AS usage", - "n.background AS background", - ] - ) - - full_node_query = f"{node_query} RETURN {return_fields}" - node_result = self.execute_query(full_node_query, timeout=20) - nodes = [] - logger.debug(f"Debugging: {node_result}") - for row in node_result: - if include_embedding: - props = row.values()[0].as_node().get_properties() - else: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - nodes.append(node) - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e - - try: - full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge" - edge_result = self.execute_query(full_edge_query, timeout=20) - edges = [ - { - "source": row.values()[0].value, - "target": row.values()[1].value, - "type": row.values()[2].value, - } - for row in edge_result - ] - except Exception as e: - raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e - - return {"nodes": nodes, "edges": edges} - - @timed - def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None: - """ - Import the entire graph from a serialized dictionary. - - Args: - data: A dictionary containing all nodes and edges to be loaded. 
- user_name (str, optional): User name for filtering in non-multi-db mode - """ - user_name = user_name if user_name else self.config.user_name - for node in data.get("nodes", []): - try: - id, memory, metadata = _compose_node(node) - metadata["user_name"] = user_name - metadata = self._prepare_node_metadata(metadata) - metadata.update({"id": id, "memory": memory}) - properties = ", ".join( - f"{k}: {self._format_value(v, k)}" for k, v in metadata.items() - ) - node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})" - self.execute_query(node_gql) - except Exception as e: - logger.error(f"Fail to load node: {node}, error: {e}") - - for edge in data.get("edges", []): - try: - source_id, target_id = edge["source"], edge["target"] - edge_type = edge["type"] - props = f'{{user_name: "{user_name}"}}' - edge_gql = f''' - MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) - INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) - ''' - self.execute_query(edge_gql) - except Exception as e: - logger.error(f"Fail to load edge: {edge}, error: {e}") - - @timed - def get_all_memory_items( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> (list)[dict]: - """ - Retrieve all memory items of a specific memory_type. - - Args: - scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'. - include_embedding: with/without embedding - user_name (str, optional): User name for filtering in non-multi-db mode - - Returns: - list[dict]: Full list of memory items under this scope. - """ - user_name = user_name if user_name else self.config.user_name - if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}: - raise ValueError(f"Unsupported memory type scope: {scope}") - - where_clause = f"WHERE n.memory_type = '{scope}'" - where_clause += f" AND n.user_name = '{user_name}'" - - return_fields = self._build_return_fields(include_embedding) - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - {where_clause} - RETURN {return_fields} - LIMIT 100 - """ - nodes = [] - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - nodes.append(self._parse_node(props)) - except Exception as e: - logger.error(f"Failed to get memories: {e}") - return nodes - - @timed - def get_structure_optimization_candidates( - self, scope: str, include_embedding: bool = False, user_name: str | None = None - ) -> list[dict]: - """ - Find nodes that are likely candidates for structure optimization: - - Isolated nodes, nodes with empty background, or nodes with exactly one child. - - Plus: the child of any parent node that has exactly one child. 
- """ - user_name = user_name if user_name else self.config.user_name - where_clause = f''' - n.memory_type = "{scope}" - AND n.status = "activated" - ''' - where_clause += f' AND n.user_name = "{user_name}"' - - return_fields = self._build_return_fields(include_embedding) - return_fields += f", n.{self.dim_field} AS {self.dim_field}" - - query = f""" - MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE {where_clause} - OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) - OPTIONAL MATCH (p@Memory)-[@PARENT]->(n) - WHERE c IS NULL AND p IS NULL - RETURN {return_fields} - """ - - candidates = [] - node_ids = set() - try: - results = self.execute_query(query) - for row in results: - props = {k: v.value for k, v in row.items()} - node = self._parse_node(props) - node_id = node["id"] - if node_id not in node_ids: - candidates.append(node) - node_ids.add(node_id) - except Exception as e: - logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}") - return candidates - - @timed - def drop_database(self) -> None: - """ - Permanently delete the entire database this instance is using. - WARNING: This operation is destructive and cannot be undone. - """ - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) - - @timed - def detect_conflicts(self) -> list[tuple[str, str]]: - """ - Detect conflicting nodes based on logical or semantic inconsistency. - Returns: - A list of (node_id1, node_id2) tuples that conflict. - """ - raise NotImplementedError - - @timed - # Structure Maintenance - def deduplicate_nodes(self) -> None: - """ - Deduplicate redundant or semantically similar nodes. - This typically involves identifying nodes with identical or near-identical memory. - """ - raise NotImplementedError - - @timed - def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: - """ - Get the ordered context chain starting from a node, following a relationship type. - Args: - id: Starting node ID. - type: Relationship type to follow (e.g., 'FOLLOWS'). - Returns: - List of ordered node IDs in the chain. - """ - raise NotImplementedError - - @timed - def get_neighbors( - self, id: str, type: str, direction: Literal["in", "out", "both"] = "out" - ) -> list[str]: - """ - Get connected node IDs in a specific direction and relationship type. - Args: - id: Source node ID. - type: Relationship type. - direction: Edge direction to follow ('out', 'in', or 'both'). - Returns: - List of neighboring node IDs. - """ - raise NotImplementedError - - @timed - def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]: - """ - Get the path of nodes from source to target within a limited depth. - Args: - source_id: Starting node ID. - target_id: Target node ID. - max_depth: Maximum path length to traverse. - Returns: - Ordered list of node IDs along the path. - """ - raise NotImplementedError - - @timed - def merge_nodes(self, id1: str, id2: str) -> str: - """ - Merge two similar or duplicate nodes into one. - Args: - id1: First node ID. - id2: Second node ID. - Returns: - ID of the resulting merged node. 
-        """
-        raise NotImplementedError
-
-    @classmethod
-    def _ensure_space_exists(cls, tmp_client, cfg):
-        """Lightweight check to ensure target graph (space) exists."""
-        db_name = getattr(cfg, "space", None)
-        if not db_name:
-            logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
-            return
-
-        try:
-            res = tmp_client.execute("SHOW GRAPHS")
-            existing = {row.values()[0].as_string() for row in res}
-            if db_name not in existing:
-                tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
-                logger.info(f"✅ Graph `{db_name}` created before session binding.")
-            else:
-                logger.debug(f"Graph `{db_name}` already exists.")
-        except Exception:
-            logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
-
-    @timed
-    def _ensure_database_exists(self):
-        graph_type_name = "MemOSBgeM3Type"
-
-        check_type_query = "SHOW GRAPH TYPES"
-        result = self.execute_query(check_type_query, auto_set_db=False)
-
-        type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
-
-        if not type_exists:
-            create_tag = f"""
-            CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
-                NODE Memory (:MemoryTag {{
-                    id STRING,
-                    memory STRING,
-                    user_name STRING,
-                    user_id STRING,
-                    session_id STRING,
-                    status STRING,
-                    key STRING,
-                    confidence FLOAT,
-                    tags LIST,
-                    created_at STRING,
-                    updated_at STRING,
-                    memory_type STRING,
-                    sources LIST,
-                    source STRING,
-                    node_type STRING,
-                    visibility STRING,
-                    usage LIST,
-                    background STRING,
-                    {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
-                    PRIMARY KEY(id)
-                }}),
-                EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
-                EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
-                EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
-                EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
-                EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
-                EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
-            }}
-            """
-            self.execute_query(create_tag, auto_set_db=False)
-        else:
-            describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
-            desc_result = self.execute_query(describe_query, auto_set_db=False)
-
-            memory_fields = []
-            for row in desc_result:
-                field_name = row.values()[0].as_string()
-                memory_fields.append(field_name)
-
-            if self.dim_field not in memory_fields:
-                alter_query = f"""
-                ALTER GRAPH TYPE {graph_type_name} {{
-                    ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
-                }}
-                """
-                self.execute_query(alter_query, auto_set_db=False)
-                logger.info(f"✅ Added vector field {self.dim_field} to {graph_type_name}")
-            else:
-                logger.info(f"✅ Graph type {graph_type_name} already includes {self.dim_field}")
-
-        create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
-        try:
-            self.execute_query(create_graph, auto_set_db=False)
-            logger.info(f"✅ Graph `{self.db_name}` is now the working graph.")
-        except Exception as e:
-            logger.error(f"❌ Failed to create graph: {e}, trace: {traceback.format_exc()}")
-
-    @timed
-    def _create_vector_index(
-        self,
-        label: str = "Memory",
-        vector_property: str = "embedding",
-        dimensions: int = 3072,
-        index_name: str = "memory_vector_index",
-    ) -> None:
-        """
-        Create a vector index for the specified property in the label.
-        """
-        if str(dimensions) == str(self.default_memory_dimension):
-            index_name = f"idx_{vector_property}"
-            vector_name = vector_property
-        else:
-            index_name = f"idx_{vector_property}_{dimensions}"
-            vector_name = f"{vector_property}_{dimensions}"
-
-        create_vector_index = f"""
-        CREATE VECTOR INDEX IF NOT EXISTS {index_name}
-        ON NODE {label}::{vector_name}
-        OPTIONS {{
-            DIM: {dimensions},
-            METRIC: IP,
-            TYPE: IVF,
-            NLIST: 100,
-            TRAINSIZE: 1000
-        }}
-        FOR `{self.db_name}`
-        """
-        self.execute_query(create_vector_index)
-        logger.info(
-            f"✅ Ensured {label}::{vector_property} vector index {index_name} "
-            f"exists (DIM={dimensions})"
-        )
-
-    @timed
-    def _create_basic_property_indexes(self) -> None:
-        """
-        Create standard B-tree indexes on the status, memory_type, created_at,
-        and updated_at fields, plus user_name when using Shared Database
-        Multi-Tenant mode.
-        """
-        fields = [
-            "status",
-            "memory_type",
-            "created_at",
-            "updated_at",
-            "user_name",
-        ]
-
-        for field in fields:
-            index_name = f"idx_memory_{field}"
-            gql = f"""
-            CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
-            FOR `{self.db_name}`
-            """
-            try:
-                self.execute_query(gql)
-                logger.info(f"✅ Created index: {index_name} on field {field}")
-            except Exception as e:
-                logger.error(
-                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
-                )
-
-    @timed
-    def _index_exists(self, index_name: str) -> bool:
-        """
-        Check if a vector index with the given name exists in NebulaGraph.
-
-        Args:
-            index_name (str): The name of the index to check.
-
-        Returns:
-            bool: True if the index exists, False otherwise.
-        """
-        query = "SHOW VECTOR INDEXES"
-        try:
-            result = self.execute_query(query)
-            return any(row.values()[0].as_string() == index_name for row in result)
-        except Exception as e:
-            logger.error(f"[Nebula] Failed to check index existence: {e}")
-            return False
-
-    @timed
-    def _parse_value(self, value: Any) -> Any:
-        """Convert a Nebula ValueWrapper to a native Python type."""
-        from nebulagraph_python.value_wrapper import ValueWrapper
-
-        if value is None or (hasattr(value, "is_null") and value.is_null()):
-            return None
-        try:
-            prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
-        except Exception as e:
-            logger.warning(f"Error when decoding Nebula ValueWrapper: {e}")
-            prim = value.cast() if isinstance(value, ValueWrapper) else value
-
-        if isinstance(prim, ValueWrapper):
-            return self._parse_value(prim)
-        if isinstance(prim, list):
-            return [self._parse_value(v) for v in prim]
-        if type(prim).__name__ == "NVector":
-            return list(prim.values)
-
-        return prim  # already a Python primitive
-
-    def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
-        parsed = {k: self._parse_value(v) for k, v in props.items()}
-
-        for tf in ("created_at", "updated_at"):
-            if tf in parsed and parsed[tf] is not None:
-                parsed[tf] = _normalize_datetime(parsed[tf])
-
-        node_id = parsed.pop("id")
-        memory = parsed.pop("memory", "")
-        parsed.pop("user_name", None)
-        metadata = parsed
-        metadata["type"] = metadata.pop("node_type")
-
-        if self.dim_field in metadata:
-            metadata["embedding"] = metadata.pop(self.dim_field)
-
-        return {"id": node_id, "memory": memory, "metadata": metadata}
-
-    @timed
-    def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
-        """
-        Ensure metadata has proper datetime fields and normalized types.
- - - Fill `created_at` and `updated_at` if missing (in ISO 8601 format). - - Convert embedding to list of float if present. - """ - now = datetime.utcnow().isoformat() - metadata["node_type"] = metadata.pop("type") - - # Fill timestamps if missing - metadata.setdefault("created_at", now) - metadata.setdefault("updated_at", now) - - # Normalize embedding type - embedding = metadata.get("embedding") - if embedding and isinstance(embedding, list): - metadata.pop("embedding") - metadata[self.dim_field] = _normalize([float(x) for x in embedding]) - - return metadata - - @timed - def _format_value(self, val: Any, key: str = "") -> str: - from nebulagraph_python.py_data_types import NVector - - # None - if val is None: - return "NULL" - # bool - if isinstance(val, bool): - return "true" if val else "false" - # str - if isinstance(val, str): - return f'"{_escape_str(val)}"' - # num - elif isinstance(val, (int | float)): - return str(val) - # time - elif isinstance(val, datetime): - return f'datetime("{val.isoformat()}")' - # list - elif isinstance(val, list): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - return f"[{', '.join(self._format_value(v) for v in val)}]" - # NVector - elif isinstance(val, NVector): - if key == self.dim_field: - dim = len(val) - joined = ",".join(str(float(x)) for x in val) - return f"VECTOR<{dim}, FLOAT>([{joined}])" - else: - logger.warning("Invalid NVector") - # dict - if isinstance(val, dict): - j = json.dumps(val, ensure_ascii=False, separators=(",", ":")) - return f'"{_escape_str(j)}"' - else: - return f'"{_escape_str(str(val))}"' - - @timed - def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]: - """ - Filter and validate metadata dictionary against the Memory node schema. - - Removes keys not in schema. - - Warns if required fields are missing. 
-        """
-
-        dim_fields = {self.dim_field}
-
-        allowed_fields = self.common_fields | dim_fields
-
-        missing_fields = allowed_fields - metadata.keys()
-        if missing_fields:
-            logger.info(f"Metadata missing expected fields: {sorted(missing_fields)}")
-
-        filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
-        return filtered_metadata
-
-    def _build_return_fields(self, include_embedding: bool = False) -> str:
-        fields = set(self.base_fields)
-        if include_embedding:
-            fields.add(self.dim_field)
-        return ", ".join(f"n.{f} AS {f}" for f in fields)
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 4d88844df..f61d981d8 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -441,7 +441,7 @@ def remove_oldest_memory(
         )
         user_name = user_name if user_name else self._get_config_value("user_name")

-        # Use actual OFFSET logic, consistent with nebular.py
+        # Use OFFSET-based selection for deterministic pruning
         # First find IDs to delete, then delete them
         select_query = f"""
             SELECT id FROM "{self.db_name}_graph"."Memory"
@@ -3539,7 +3539,7 @@ def get_neighbors_by_tag_ccl(

         user_name = user_name if user_name else self._get_config_value("user_name")

-        # Build query conditions; keep consistent with nebular.py
+        # Build query conditions shared with other graph backends
         where_clauses = [
             'n.status = "activated"',
             'NOT (n.node_type = "reasoning")',
@@ -3592,7 +3592,7 @@ def get_neighbors_by_tag_ccl(
         # Add overlap_count
         result_fields.append("overlap_count agtype")
         result_fields_str = ", ".join(result_fields)
-        # Use Cypher query; keep consistent with nebular.py
+        # Use a Cypher query to keep the graph query path aligned across backends
         query = f"""
             SELECT * FROM (
                 SELECT * FROM cypher('{self.db_name}_graph', $$
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index ec431c253..671190e6f 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
     graph_db_backend_map = {
         "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
         "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
-        "nebular": APIConfig.get_nebular_config(user_id=user_id),
         "polardb": APIConfig.get_polardb_config(user_id=user_id),
         "postgres": APIConfig.get_postgres_config(user_id=user_id),
     }

     # Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
-    graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+    graph_db_backend = os.getenv(
+        "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+    ).lower()
     return GraphDBConfigFactory.model_validate(
         {
             "backend": graph_db_backend,
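With `nebular` removed from `graph_db_backend_map`, backend resolution now falls back to `neo4j-community` when neither environment variable is set. The precedence order can be checked in isolation — a minimal sketch, assuming only the `GRAPH_DB_BACKEND` / legacy `NEO4J_BACKEND` variables used in the hunk above; `resolve_graph_backend` and `SUPPORTED_BACKENDS` are hypothetical names for illustration, not part of the MemOS API:

```python
import os

# Backends still wired into graph_db_backend_map after this change.
SUPPORTED_BACKENDS = {"neo4j-community", "neo4j", "polardb", "postgres"}


def resolve_graph_backend() -> str:
    """Mirror the fallback chain above: GRAPH_DB_BACKEND wins,
    the legacy NEO4J_BACKEND is honored next, then the new default."""
    backend = os.getenv(
        "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
    ).lower()
    if backend not in SUPPORTED_BACKENDS:
        # e.g. a stale NEO4J_BACKEND=nebular left over from an old .env
        raise ValueError(
            f"Unsupported graph DB backend {backend!r}; "
            f"expected one of {sorted(SUPPORTED_BACKENDS)}"
        )
    return backend


if __name__ == "__main__":
    os.environ.setdefault("NEO4J_BACKEND", "polardb")  # simulate a legacy .env
    print(resolve_graph_backend())  # -> "polardb" (when GRAPH_DB_BACKEND is unset)
```

One consequence worth noting: deployments that still set `NEO4J_BACKEND=nebular` (the previous default) will no longer resolve to a known backend and should fail config validation instead of silently loading removed code, which appears to be the intent of this hunk.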