diff --git a/README.md b/README.md
index a7b05d683..1c7d5bd93 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@
- [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) โ intelligent memory retrieval instead of loading full chat history
- [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) โ multi-instance agents share memory via same user_id, automatic context handoff
-Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
+Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)
### ๐ง Local Plugin โ 100% On-Device Memory
@@ -84,7 +84,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **Hybrid search + task & skill evolution** โ FTS5 + vector search, auto task summarization, reusable skills that self-upgrade
- **Multi-agent collaboration + Memory Viewer** โ memory isolation, skill sharing, full web dashboard with 7 management pages
- ๐ [Homepage](https://memos-claw.openmem.net) ยท
+ ๐ [Homepage](https://memos-claw.openmem.net) ยท
๐ [Documentation](https://memos-claw.openmem.net/docs/index.html) ยท ๐ฆ [NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin)
## ๐ MemOS: Memory Operating System for AI Agents
@@ -104,10 +104,10 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
### News
-- **2026-03-08** ยท ๐ฆ **MemOS OpenClaw Plugin โ Cloud & Local**
+- **2026-03-08** ยท ๐ฆ **MemOS OpenClaw Plugin โ Cloud & Local**
Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard.
-- **2025-12-24** ยท ๐ **MemOS v2.0: Stardust (ๆๅฐ) Release**
+- **2025-12-24** ยท ๐ **MemOS v2.0: Stardust (ๆๅฐ) Release**
Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
โจ New Features
@@ -155,7 +155,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** ยท ๐ **MemOS v1.0.0 (MemCube) Release**
- First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch.
+ First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
โจ New Features
@@ -176,7 +176,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
**Plaintext Memory**
- Integrated internet search with Bocha.
- - Added support for Nebula database.
+ - Expanded graph database support.
- Added contextual understanding for the tree-structured plaintext memory search interface.
@@ -188,7 +188,7 @@ Full tutorial โ [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- Fixed the concat_cache method.
**Plaintext Memory**
- - Fixed Nebula search-related issues.
+ - Fixed graph search-related issues.
diff --git a/docker/.env.example b/docker/.env.example
index 3674cd69b..08f9cf1db 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -127,7 +127,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false
## Graph / vector stores
# Neo4j database selection mode
-NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb
+NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres
# Neo4j database url
NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j*
# Neo4j database user
diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
index 6dbe202c2..d14b6d687 100644
--- a/examples/mem_agent/deepsearch_example.py
+++ b/examples/mem_agent/deepsearch_example.py
@@ -47,7 +47,7 @@ def build_minimal_components():
# Build component configurations using APIConfig methods (like config_builders.py)
- # Graph DB configuration - using APIConfig.get_nebular_config()
+ # Graph DB configuration - using APIConfig graph DB helpers
graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower()
graph_db_backend_map = {
"polardb": APIConfig.get_polardb_config(),
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 87f1efd8e..2e8bae57c 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -741,21 +741,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
}
- @staticmethod
- def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
- """Get Nebular configuration."""
- return {
- "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
- "user": os.getenv("NEBULAR_USER", "root"),
- "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
- "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"),
- "user_name": f"memos{user_id.replace('-', '')}",
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
- }
-
- @staticmethod
def get_milvus_config():
return {
"collection_name": [
@@ -1103,7 +1088,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
neo4j_config = APIConfig.get_neo4j_config(user_id)
- nebular_config = APIConfig.get_nebular_config(user_id)
polardb_config = APIConfig.get_polardb_config(user_id)
internet_config = (
APIConfig.get_internet_config()
@@ -1114,7 +1098,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1144,9 +1127,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1169,7 +1152,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
default_mem_cube = GeneralMemCube(default_cube_config)
return default_config, default_mem_cube
@@ -1188,13 +1171,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
openai_config = APIConfig.get_openai_config()
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
- nebular_config = APIConfig.get_nebular_config(user_id="default")
polardb_config = APIConfig.get_polardb_config(user_id="default")
postgres_config = APIConfig.get_postgres_config(user_id="default")
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1227,9 +1208,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
== "true",
"internet_retriever": internet_config,
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1253,4 +1234,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index 5655bf1e5..d29429fc9 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 5900d2357..98de09812 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -103,57 +103,6 @@ def validate_community(self):
return self
-class NebulaGraphDBConfig(BaseGraphDBConfig):
- """
- NebulaGraph-specific configuration.
-
- Key concepts:
- - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space.
- - `user_name`: Used for logical tenant isolation if needed.
- - `auto_create`: Whether to automatically create the target space if it does not exist.
-
- Example:
- ---
- hosts = ["127.0.0.1:9669"]
- user = "root"
- password = "nebula"
- space = "shared_graph"
- user_name = "alice"
- """
-
- space: str = Field(
- ..., description="The name of the target NebulaGraph space (like a database)"
- )
- user_name: str | None = Field(
- default=None,
- description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)",
- )
- auto_create: bool = Field(
- default=False,
- description="Whether to auto-create the space if it does not exist",
- )
- use_multi_db: bool = Field(
- default=True,
- description=(
- "If True: use Neo4j's multi-database feature for physical isolation; "
- "each user typically gets a separate database. "
- "If False: use a single shared database with logical isolation by user_name."
- ),
- )
- max_client: int = Field(
- default=1000,
- description=("max_client"),
- )
- embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
-
- @model_validator(mode="after")
- def validate_config(self):
- """Validate config."""
- if not self.space:
- raise ValueError("`space` must be provided")
- return self
-
-
class PolarDBGraphDBConfig(BaseConfig):
"""
PolarDB-specific configuration.
@@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDBConfig,
"neo4j-community": Neo4jCommunityGraphDBConfig,
- "nebular": NebulaGraphDBConfig,
"polardb": PolarDBGraphDBConfig,
"postgres": PostgresGraphDBConfig,
}
diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py
index c207e3190..93b5971ec 100644
--- a/src/memos/graph_dbs/factory.py
+++ b/src/memos/graph_dbs/factory.py
@@ -2,7 +2,6 @@
from memos.configs.graph_db import GraphDBConfigFactory
from memos.graph_dbs.base import BaseGraphDB
-from memos.graph_dbs.nebular import NebulaGraphDB
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
from memos.graph_dbs.polardb import PolarDBGraphDB
@@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDB,
"neo4j-community": Neo4jCommunityGraphDB,
- "nebular": NebulaGraphDB,
"polardb": PolarDBGraphDB,
"postgres": PostgresGraphDB,
}
diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
deleted file mode 100644
index 428d6d09e..000000000
--- a/src/memos/graph_dbs/nebular.py
+++ /dev/null
@@ -1,1794 +0,0 @@
-import json
-import traceback
-
-from contextlib import suppress
-from datetime import datetime
-from threading import Lock
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
-
-import numpy as np
-
-from memos.configs.graph_db import NebulaGraphDBConfig
-from memos.dependency import require_python_package
-from memos.graph_dbs.base import BaseGraphDB
-from memos.log import get_logger
-from memos.utils import timed
-
-
-if TYPE_CHECKING:
- from nebulagraph_python import (
- NebulaClient,
- )
-
-
-logger = get_logger(__name__)
-
-
-_TRANSIENT_ERR_KEYS = (
- "Session not found",
- "Connection not established",
- "timeout",
- "deadline exceeded",
- "Broken pipe",
- "EOFError",
- "socket closed",
- "connection reset",
- "connection refused",
-)
-
-
-@timed
-def _normalize(vec: list[float]) -> list[float]:
- v = np.asarray(vec, dtype=np.float32)
- norm = np.linalg.norm(v)
- return (v / (norm if norm else 1.0)).tolist()
-
-
-@timed
-def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
- node_id = item["id"]
- memory = item["memory"]
- metadata = item.get("metadata", {})
- return node_id, memory, metadata
-
-
-@timed
-def _escape_str(value: str) -> str:
- out = []
- for ch in value:
- code = ord(ch)
- if ch == "\\":
- out.append("\\\\")
- elif ch == '"':
- out.append('\\"')
- elif ch == "\n":
- out.append("\\n")
- elif ch == "\r":
- out.append("\\r")
- elif ch == "\t":
- out.append("\\t")
- elif ch == "\b":
- out.append("\\b")
- elif ch == "\f":
- out.append("\\f")
- elif code < 0x20 or code in (0x2028, 0x2029):
- out.append(f"\\u{code:04x}")
- else:
- out.append(ch)
- return "".join(out)
-
-
-@timed
-def _format_datetime(value: str | datetime) -> str:
- """Ensure datetime is in ISO 8601 format string."""
- if isinstance(value, datetime):
- return value.isoformat()
- return str(value)
-
-
-@timed
-def _normalize_datetime(val):
- """
- Normalize datetime to ISO 8601 UTC string with +00:00.
- - If val is datetime object -> keep isoformat() (Neo4j)
- - If val is string without timezone -> append +00:00 (Nebula)
- - Otherwise just str()
- """
- if hasattr(val, "isoformat"):
- return val.isoformat()
- if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
- return val + "+08:00"
- return str(val)
-
-
-class NebulaGraphDB(BaseGraphDB):
- """
- NebulaGraph-based implementation of a graph memory store.
- """
-
- # ====== shared pool cache & refcount ======
- # These are process-local; in a multi-process model each process will
- # have its own cache.
- _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
- _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
- _CLIENT_LOCK: ClassVar[Lock] = Lock()
- _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
-
- @staticmethod
- def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
- hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
- if isinstance(hosts, str):
- return [hosts]
- return list(hosts or [])
-
- @staticmethod
- def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
- hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
- return "|".join(
- [
- "nebula-sync",
- ",".join(hosts),
- str(getattr(cfg, "user", "")),
- str(getattr(cfg, "space", "")),
- ]
- )
-
- @classmethod
- def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
- tmp = object.__new__(NebulaGraphDB)
- tmp.config = cfg
- tmp.db_name = cfg.space
- tmp.user_name = None
- tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
- tmp.default_memory_dimension = 3072
- tmp.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- tmp.base_fields = set(tmp.common_fields) - {"usage"}
- tmp.heavy_fields = {"usage"}
- tmp.dim_field = (
- f"embedding_{tmp.embedding_dimension}"
- if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
- else "embedding"
- )
- tmp.system_db_name = cfg.space
- tmp._client = client
- tmp._owns_client = False
- return tmp
-
- @classmethod
- def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
- from nebulagraph_python import (
- ConnectionConfig,
- NebulaClient,
- SessionConfig,
- SessionPoolConfig,
- )
-
- key = cls._make_client_key(cfg)
- with cls._CLIENT_LOCK:
- client = cls._CLIENT_CACHE.get(key)
- if client is None:
- # Connection setting
-
- tmp_client = NebulaClient(
- hosts=cfg.uri,
- username=cfg.user,
- password=cfg.password,
- session_config=SessionConfig(graph=None),
- session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
- )
- try:
- cls._ensure_space_exists(tmp_client, cfg)
- finally:
- tmp_client.close()
-
- conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
- if conn_conf is None:
- conn_conf = ConnectionConfig.from_defults(
- cls._get_hosts_from_cfg(cfg),
- getattr(cfg, "ssl_param", None),
- )
-
- sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
- pool_conf = SessionPoolConfig(
- size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
- )
-
- client = NebulaClient(
- hosts=conn_conf.hosts,
- username=cfg.user,
- password=cfg.password,
- conn_config=conn_conf,
- session_config=sess_conf,
- session_pool_config=pool_conf,
- )
- cls._CLIENT_CACHE[key] = client
- cls._CLIENT_REFCOUNT[key] = 0
- logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
-
- cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- try:
- pass
- finally:
- pass
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_INIT_DONE:
- admin = cls._bootstrap_admin(cfg, client)
- try:
- admin._ensure_database_exists()
- admin._create_basic_property_indexes()
- admin._create_vector_index(
- dimensions=int(
- admin.embedding_dimension or admin.default_memory_dimension
- ),
- )
- cls._CLIENT_INIT_DONE.add(key)
- logger.info("[NebulaGraphDBSync] One-time init done")
- except Exception:
- logger.exception("[NebulaGraphDBSync] One-time init failed")
-
- return key, client
-
- def _refresh_client(self):
- """
- refresh NebulaClient:
- """
- old_key = getattr(self, "_client_key", None)
- if not old_key:
- return
-
- cls = self.__class__
- with cls._CLIENT_LOCK:
- try:
- if old_key in cls._CLIENT_CACHE:
- try:
- cls._CLIENT_CACHE[old_key].close()
- except Exception as e:
- logger.warning(f"[refresh_client] close old client error: {e}")
- finally:
- cls._CLIENT_CACHE.pop(old_key, None)
- finally:
- cls._CLIENT_REFCOUNT[old_key] = 0
-
- new_key, new_client = cls._get_or_create_shared_client(self.config)
- self._client_key = new_key
- self._client = new_client
- logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
-
- @classmethod
- def _release_shared_client(cls, key: str):
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_CACHE:
- return
- cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
- if cls._CLIENT_REFCOUNT[key] == 0:
- try:
- cls._CLIENT_CACHE[key].close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
- finally:
- cls._CLIENT_CACHE.pop(key, None)
- cls._CLIENT_REFCOUNT.pop(key, None)
- logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
-
- @classmethod
- def close_all_shared_clients(cls):
- with cls._CLIENT_LOCK:
- for key, client in list(cls._CLIENT_CACHE.items()):
- try:
- client.close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
- finally:
- logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
- cls._CLIENT_CACHE.clear()
- cls._CLIENT_REFCOUNT.clear()
-
- @require_python_package(
- import_name="nebulagraph_python",
- install_command="pip install nebulagraph-python>=5.1.1",
- install_link=".....",
- )
- def __init__(self, config: NebulaGraphDBConfig):
- """
- NebulaGraph DB client initialization.
-
- Required config attributes:
- - hosts: list[str] like ["host1:port", "host2:port"]
- - user: str
- - password: str
- - db_name: str (optional for basic commands)
-
- Example config:
- {
- "hosts": ["xxx.xx.xx.xxx:xxxx"],
- "user": "root",
- "password": "nebula",
- "space": "test"
- }
- """
-
- assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
- self.config = config
- self.db_name = config.space
- self.user_name = config.user_name
- self.embedding_dimension = config.embedding_dimension
- self.default_memory_dimension = 3072
- self.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- self.base_fields = set(self.common_fields) - {"usage"}
- self.heavy_fields = {"usage"}
- self.dim_field = (
- f"embedding_{self.embedding_dimension}"
- if (str(self.embedding_dimension) != str(self.default_memory_dimension))
- else "embedding"
- )
- self.system_db_name = config.space
-
- # ---- NEW: pool acquisition strategy
- # Get or create a shared pool from the class-level cache
- self._client_key, self._client = self._get_or_create_shared_client(config)
- self._owns_client = True
-
- logger.info("Connected to NebulaGraph successfully.")
-
- @timed
- def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
- def _wrap_use_db(q: str) -> str:
- if auto_set_db and self.db_name:
- return f"USE `{self.db_name}`\n{q}"
- return q
-
- try:
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
-
- except Exception as e:
- emsg = str(e)
- if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
- logger.warning(f"[execute_query] {e!s} โ refreshing session pool and retry once...")
- try:
- self._refresh_client()
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
- except Exception:
- logger.exception("[execute_query] retry after refresh failed")
- raise
- raise
-
- @timed
- def close(self):
- """
- Close the connection resource if this instance owns it.
-
- - If pool was injected (`shared_pool`), do nothing.
- - If pool was acquired via shared cache, decrement refcount and close
- when the last owner releases it.
- """
- if not self._owns_client:
- logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
- return
- if self._client_key:
- self._release_shared_client(self._client_key)
- self._client_key = None
- self._client = None
-
- # NOTE: __del__ is best-effort; do not rely on GC order.
- def __del__(self):
- with suppress(Exception):
- self.close()
-
- @timed
- def create_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- # Create vector index
- self._create_vector_index(label, vector_property, dimensions, index_name)
- # Create indexes
- self._create_basic_property_indexes()
-
- @timed
- def remove_oldest_memory(
- self, memory_type: str, keep_latest: int, user_name: str | None = None
- ) -> None:
- """
- Remove all WorkingMemory nodes except the latest `keep_latest` entries.
-
- Args:
- memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
- keep_latest (int): Number of latest WorkingMemory entries to keep.
- user_name(str): optional user_name.
- """
- try:
- user_name = user_name if user_name else self.config.user_name
- optional_condition = f"AND n.user_name = '{user_name}'"
- count = self.count_nodes(memory_type, user_name)
- if count > keep_latest:
- delete_query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.memory_type = '{memory_type}'
- {optional_condition}
- ORDER BY n.updated_at DESC
- OFFSET {int(keep_latest)}
- DETACH DELETE n
- """
- self.execute_query(delete_query)
- except Exception as e:
- logger.warning(f"Delete old mem error: {e}")
-
- @timed
- def add_node(
- self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
- ) -> None:
- """
- Insert or update a Memory node in NebulaGraph.
- """
- metadata["user_name"] = user_name if user_name else self.config.user_name
- now = datetime.utcnow()
- metadata = metadata.copy()
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
- metadata["node_type"] = metadata.pop("type")
- metadata["id"] = id
- metadata["memory"] = memory
-
- if "embedding" in metadata and isinstance(metadata["embedding"], list):
- assert len(metadata["embedding"]) == self.embedding_dimension, (
- f"input embedding dimension must equal to {self.embedding_dimension}"
- )
- embedding = metadata.pop("embedding")
- metadata[self.dim_field] = _normalize(embedding)
-
- metadata = self._metadata_filter(metadata)
- properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
- gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
-
- try:
- self.execute_query(gql)
- logger.info("insert success")
- except Exception as e:
- logger.error(
- f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
- )
-
- @timed
- def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"'
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {filter_clause}
- RETURN n.id AS id
- LIMIT 1
- """
-
- try:
- result = self.execute_query(query)
- return result.size == 0
- except Exception as e:
- logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
- raise
-
- @timed
- def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
- """
- Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
- """
- user_name = user_name if user_name else self.config.user_name
- fields = fields.copy()
- set_clauses = []
- for k, v in fields.items():
- set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
-
- set_clause_str = ",\n ".join(set_clauses)
-
- query = f"""
- MATCH (n@Memory {{id: "{id}"}})
- """
- query += f'WHERE n.user_name = "{user_name}"'
-
- query += f"\nSET {set_clause_str}"
- self.execute_query(query)
-
- @timed
- def delete_node(self, id: str, user_name: str | None = None) -> None:
- """
- Delete a node from the graph.
- Args:
- id: Node identifier to delete.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)}
- DETACH DELETE n
- """
- self.execute_query(query)
-
- @timed
- def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None):
- """
- Create an edge from source node to target node.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- if not source_id or not target_id:
- raise ValueError("[add_edge] source_id and target_id must be provided")
- user_name = user_name if user_name else self.config.user_name
- props = ""
- props = f'{{user_name: "{user_name}"}}'
- insert_stmt = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT (a) -[e@{type} {props}]-> (b)
- '''
- try:
- self.execute_query(insert_stmt)
- except Exception as e:
- logger.error(f"Failed to insert edge: {e}", exc_info=True)
-
- @timed
- def delete_edge(
- self, source_id: str, target_id: str, type: str, user_name: str | None = None
- ) -> None:
- """
- Delete a specific edge between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type to remove.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (a@Memory) -[r@{type}]-> (b@Memory)
- WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
- """
-
- query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
- query += "\nDELETE r"
- self.execute_query(query)
-
- @timed
- def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{memory_type}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN COUNT(n) AS count"
-
- try:
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
- except Exception as e:
- logger.error(f"[get_memory_count] Failed: {e}")
- return -1
-
- @timed
- def count_nodes(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{scope}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN count(n) AS count"
-
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
-
- @timed
- def edge_exists(
- self,
- source_id: str,
- target_id: str,
- type: str = "ANY",
- direction: str = "OUTGOING",
- user_name: str | None = None,
- ) -> bool:
- """
- Check if an edge exists between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type. Use "ANY" to match any relationship type.
- direction: Direction of the edge.
- Use "OUTGOING" (default), "INCOMING", or "ANY".
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- True if the edge exists, otherwise False.
- """
- # Prepare the relationship pattern
- user_name = user_name if user_name else self.config.user_name
- rel = "r" if type == "ANY" else f"r@{type}"
-
- # Prepare the match pattern with direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
- elif direction == "INCOMING":
- pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- elif direction == "ANY":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- else:
- raise ValueError(
- f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
- )
- query = f"MATCH {pattern}"
- query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
- query += "\nRETURN r"
-
- # Run the Cypher query
- result = self.execute_query(query)
- record = result.one_or_none()
- if record is None:
- return False
- return record.values() is not None
-
- @timed
- # Graph Query & Reasoning
- def get_node(
- self, id: str, include_embedding: bool = False, user_name: str | None = None
- ) -> dict[str, Any] | None:
- """
- Retrieve a Memory node by its unique ID.
-
- Args:
- id (str): Node ID (Memory.id)
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- dict: Node properties as key-value pairs, or None if not found.
- """
- filter_clause = f'n.id = "{id}"'
- return_fields = self._build_return_fields(include_embedding)
- gql = f"""
- MATCH (n@Memory)
- WHERE {filter_clause}
- RETURN {return_fields}
- """
-
- try:
- result = self.execute_query(gql)
- for row in result:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- return node
-
- except Exception as e:
- logger.error(
- f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
- )
- return None
-
- @timed
- def get_nodes(
- self,
- ids: list[str],
- include_embedding: bool = False,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict[str, Any]]:
- """
- Retrieve the metadata and memory of a list of nodes.
- Args:
- ids: List of Node identifier.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
-
- Notes:
- - Assumes all provided IDs are valid and exist.
- - Returns empty list if input is empty.
- """
- if not ids:
- return []
- # Safe formatting of the ID list
- id_list = ",".join(f'"{_id}"' for _id in ids)
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.id IN [{id_list}]
- RETURN {return_fields}
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(
- f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
- )
- return nodes
-
- @timed
- def get_edges(
- self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
- ) -> list[dict[str, str]]:
- """
- Get edges connected to a node, with optional type and direction filter.
-
- Args:
- id: Node ID to retrieve edges for.
- type: Relationship type to match, or 'ANY' to match all.
- direction: 'OUTGOING', 'INCOMING', or 'ANY'.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of edges:
- [
- {"from": "source_id", "to": "target_id", "type": "RELATE"},
- ...
- ]
- """
- # Build relationship type filter
- rel_type = "" if type == "ANY" else f"@{type}"
- user_name = user_name if user_name else self.config.user_name
- # Build Cypher pattern based on direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "INCOMING":
- pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "ANY":
- pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}' OR b.id = '{id}'"
- else:
- raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
-
- where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
-
- query = f"""
- MATCH {pattern}
- WHERE {where_clause}
- RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
- """
-
- result = self.execute_query(query)
- edges = []
- for record in result:
- edges.append(
- {
- "from": record["from_id"].value,
- "to": record["to_id"].value,
- "type": record["edge_type"].value,
- }
- )
- return edges
-
- @timed
- def get_neighbors_by_tag(
- self,
- tags: list[str],
- exclude_ids: list[str],
- top_k: int = 5,
- min_overlap: int = 1,
- include_embedding: bool = False,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Find top-K neighbor nodes with maximum tag overlap.
-
- Args:
- tags: The list of tags to match.
- exclude_ids: Node IDs to exclude (e.g., local cluster).
- top_k: Max number of neighbors to return.
- min_overlap: Minimum number of overlapping tags required.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of dicts with node details and overlap count.
- """
- if not tags:
- return []
- user_name = user_name if user_name else self.config.user_name
- where_clauses = [
- 'n.status = "activated"',
- 'NOT (n.node_type = "reasoning")',
- 'NOT (n.memory_type = "WorkingMemory")',
- ]
- if exclude_ids:
- where_clauses.append(f"NOT (n.id IN {exclude_ids})")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_clause = " AND ".join(where_clauses)
- tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- LET tag_list = {tag_list_literal}
-
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- RETURN {return_fields},
- size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
- ORDER BY overlap_count DESC
- LIMIT {top_k}
- """
-
- result = self.execute_query(query)
- neighbors: list[dict[str, Any]] = []
- for r in result:
- props = {k: v.value for k, v in r.items() if k != "overlap_count"}
- parsed = self._parse_node(props)
- parsed["overlap_count"] = r["overlap_count"].value
- neighbors.append(parsed)
-
- neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
- neighbors = neighbors[:top_k]
- result = []
- for neighbor in neighbors[:top_k]:
- neighbor.pop("overlap_count")
- result.append(neighbor)
- return result
-
- @timed
- def get_children_with_embeddings(
- self, id: str, user_name: str | None = None
- ) -> list[dict[str, Any]]:
- user_name = user_name if user_name else self.config.user_name
- where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"
-
- query = f"""
- MATCH (p@Memory)-[@PARENT]->(c@Memory)
- WHERE p.id = "{id}" {where_user}
- RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
- """
- result = self.execute_query(query)
- children = []
- for row in result:
- eid = row["id"].value # STRING
- emb_v = row[self.dim_field].value # NVector
- emb = list(emb_v.values) if emb_v else []
- mem = row["memory"].value # STRING
-
- children.append({"id": eid, "embedding": emb, "memory": mem})
- return children
-
- @timed
- def get_subgraph(
- self,
- center_id: str,
- depth: int = 2,
- center_status: str = "activated",
- user_name: str | None = None,
- ) -> dict[str, Any]:
- """
- Retrieve a local subgraph centered at a given node.
- Args:
- center_id: The ID of the center node.
- depth: The hop distance for neighbors.
- center_status: Required status for center node.
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- {
- "core_node": {...},
- "neighbors": [...],
- "edges": [...]
- }
- """
- if not 1 <= depth <= 5:
- raise ValueError("depth must be 1-5")
-
- user_name = user_name if user_name else self.config.user_name
-
- gql = f"""
- MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE center.id = '{center_id}'
- AND center.status = '{center_status}'
- AND center.user_name = '{user_name}'
- OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
- WHERE neighbor.user_name = '{user_name}'
- RETURN center,
- collect(DISTINCT neighbor) AS neighbors,
- collect(EDGES(p)) AS edge_chains
- """
-
- result = self.execute_query(gql).one_or_none()
- if not result or result.size == 0:
- return {"core_node": None, "neighbors": [], "edges": []}
-
- core_node_props = result["center"].as_node().get_properties()
- core_node = self._parse_node(core_node_props)
- neighbors = []
- vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
- for n in result["neighbors"].value:
- n_node = n.as_node()
- n_props = n_node.get_properties()
- node_parsed = self._parse_node(n_props)
- neighbors.append(node_parsed)
- vid_to_id_map[n_node.node_id] = node_parsed["id"]
-
- edges = []
- for chain_group in result["edge_chains"].value:
- for edge_wr in chain_group.value:
- edge = edge_wr.value
- edges.append(
- {
- "type": edge.get_type(),
- "source": vid_to_id_map.get(edge.get_src_id()),
- "target": vid_to_id_map.get(edge.get_dst_id()),
- }
- )
-
- return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
-
- @timed
- # Search / recall operations
- def search_by_embedding(
- self,
- vector: list[float],
- top_k: int = 5,
- scope: str | None = None,
- status: str | None = None,
- threshold: float | None = None,
- search_filter: dict | None = None,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict]:
- """
- Retrieve node IDs based on vector similarity.
-
- Args:
- vector (list[float]): The embedding vector representing query semantics.
- top_k (int): Number of top similar nodes to retrieve.
- scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
- status (str, optional): Node status filter (e.g., 'active', 'archived').
- If provided, restricts results to nodes with matching status.
- threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
- search_filter (dict, optional): Additional metadata filters for search results.
- Keys should match node properties, values are the expected values.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
-
- Notes:
- - This method uses Neo4j native vector indexing to search for similar nodes.
- - If scope is provided, it restricts results to nodes with matching memory_type.
- - If 'status' is provided, only nodes with the matching status will be returned.
- - If threshold is provided, only results with score >= threshold will be returned.
- - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- - Typical use case: restrict to 'status = activated' to avoid
- matching archived or merged nodes.
- """
- user_name = user_name if user_name else self.config.user_name
- vector = _normalize(vector)
- dim = len(vector)
- vector_str = ",".join(f"{float(x)}" for x in vector)
- gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
- where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
- if scope:
- where_clauses.append(f'n.memory_type = "{scope}"')
- if status:
- where_clauses.append(f'n.status = "{status}"')
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- # Add search_filter conditions
- if search_filter:
- for key, value in search_filter.items():
- if isinstance(value, str):
- where_clauses.append(f'n.{key} = "{value}"')
- else:
- where_clauses.append(f"n.{key} = {value}")
-
- where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
-
- gql = f"""
- let a = {gql_vector}
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- ORDER BY inner_product(n.{self.dim_field}, a) DESC
- LIMIT {top_k}
- RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
- try:
- result = self.execute_query(gql)
- except Exception as e:
- logger.error(f"[search_by_embedding] Query failed: {e}")
- return []
-
- try:
- output = []
- for row in result:
- values = row.values()
- id_val = values[0].as_string()
- score_val = values[1].as_double()
- score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
- if threshold is None or score_val >= threshold:
- output.append({"id": id_val, "score": score_val})
- return output
- except Exception as e:
- logger.error(f"[search_by_embedding] Result parse failed: {e}")
- return []
-
- @timed
- def get_by_metadata(
- self, filters: list[dict[str, Any]], user_name: str | None = None
- ) -> list[str]:
- """
- 1. ADD logic: "AND" vs "OR"(support logic combination);
- 2. Support nested conditional expressions;
-
- Retrieve node IDs that match given metadata filters.
- Supports exact match.
-
- Args:
- filters: List of filter dicts like:
- [
- {"field": "key", "op": "in", "value": ["A", "B"]},
- {"field": "confidence", "op": ">=", "value": 80},
- {"field": "tags", "op": "contains", "value": "AI"},
- ...
- ]
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
-
- Notes:
- - Supports structured querying such as tag/category/importance/time filtering.
- - Can be used for faceted recall or prefiltering before embedding rerank.
- """
- where_clauses = []
- user_name = user_name if user_name else self.config.user_name
- for _i, f in enumerate(filters):
- field = f["field"]
- op = f.get("op", "=")
- value = f["value"]
-
- escaped_value = self._format_value(value)
-
- # Build WHERE clause
- if op == "=":
- where_clauses.append(f"n.{field} = {escaped_value}")
- elif op == "in":
- where_clauses.append(f"n.{field} IN {escaped_value}")
- elif op == "contains":
- where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
- elif op == "starts_with":
- where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
- elif op == "ends_with":
- where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
- elif op in [">", ">=", "<", "<="]:
- where_clauses.append(f"n.{field} {op} {escaped_value}")
- else:
- raise ValueError(f"Unsupported operator: {op}")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_str = " AND ".join(where_clauses)
- gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
- ids = []
- try:
- result = self.execute_query(gql)
- ids = [record["id"].value for record in result]
- except Exception as e:
- logger.error(f"Failed to get metadata: {e}, gql is {gql}")
- return ids
-
- @timed
- def get_grouped_counts(
- self,
- group_fields: list[str],
- where_clause: str = "",
- params: dict[str, Any] | None = None,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Count nodes grouped by any fields.
-
- Args:
- group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
- where_clause (str, optional): Extra WHERE condition. E.g.,
- "WHERE n.status = 'activated'"
- params (dict, optional): Parameters for WHERE clause.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
- """
- if not group_fields:
- raise ValueError("group_fields cannot be empty")
- user_name = user_name if user_name else self.config.user_name
- # GQL-specific modifications
- user_clause = f"n.user_name = '{user_name}'"
- if where_clause:
- where_clause = where_clause.strip()
- if where_clause.upper().startswith("WHERE"):
- where_clause += f" AND {user_clause}"
- else:
- where_clause = f"WHERE {where_clause} AND {user_clause}"
- else:
- where_clause = f"WHERE {user_clause}"
-
- # Inline parameters if provided
- if params:
- for key, value in params.items():
- # Handle different value types appropriately
- if isinstance(value, str):
- value = f"'{value}'"
- where_clause = where_clause.replace(f"${key}", str(value))
-
- return_fields = []
- group_by_fields = []
-
- for field in group_fields:
- alias = field.replace(".", "_")
- return_fields.append(f"n.{field} AS {alias}")
- group_by_fields.append(alias)
- # Full GQL query construction
- gql = f"""
- MATCH (n /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {", ".join(return_fields)}, COUNT(n) AS count
- """
- result = self.execute_query(gql) # Pure GQL string execution
-
- output = []
- for record in result:
- group_values = {}
- for i, field in enumerate(group_fields):
- value = record.values()[i].as_string()
- group_values[field] = value
- count_value = record["count"].value
- output.append({**group_values, "count": count_value})
-
- return output
-
- @timed
- def clear(self, user_name: str | None = None) -> None:
- """
- Clear the entire graph if the target database exists.
-
- Args:
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- try:
- query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n"
- self.execute_query(query)
- logger.info("Cleared all nodes from database.")
-
- except Exception as e:
- logger.error(f"[ERROR] Failed to clear database: {e}")
-
- @timed
- def export_graph(
- self, include_embedding: bool = False, user_name: str | None = None, **kwargs
- ) -> dict[str, Any]:
- """
- Export all graph nodes and edges in a structured form.
- Args:
- include_embedding (bool): Whether to include the large embedding field.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- {
- "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
- "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
- }
- """
- user_name = user_name if user_name else self.config.user_name
- node_query = "MATCH (n@Memory)"
- edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
- node_query += f' WHERE n.user_name = "{user_name}"'
- edge_query += f' WHERE r.user_name = "{user_name}"'
-
- try:
- if include_embedding:
- return_fields = "n"
- else:
- return_fields = ",".join(
- [
- "n.id AS id",
- "n.memory AS memory",
- "n.user_name AS user_name",
- "n.user_id AS user_id",
- "n.session_id AS session_id",
- "n.status AS status",
- "n.key AS key",
- "n.confidence AS confidence",
- "n.tags AS tags",
- "n.created_at AS created_at",
- "n.updated_at AS updated_at",
- "n.memory_type AS memory_type",
- "n.sources AS sources",
- "n.source AS source",
- "n.node_type AS node_type",
- "n.visibility AS visibility",
- "n.usage AS usage",
- "n.background AS background",
- ]
- )
-
- full_node_query = f"{node_query} RETURN {return_fields}"
- node_result = self.execute_query(full_node_query, timeout=20)
- nodes = []
- logger.debug(f"Debugging: {node_result}")
- for row in node_result:
- if include_embedding:
- props = row.values()[0].as_node().get_properties()
- else:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- nodes.append(node)
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
-
- try:
- full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
- edge_result = self.execute_query(full_edge_query, timeout=20)
- edges = [
- {
- "source": row.values()[0].value,
- "target": row.values()[1].value,
- "type": row.values()[2].value,
- }
- for row in edge_result
- ]
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
-
- return {"nodes": nodes, "edges": edges}
-
- @timed
- def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
- """
- Import the entire graph from a serialized dictionary.
-
- Args:
- data: A dictionary containing all nodes and edges to be loaded.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- for node in data.get("nodes", []):
- try:
- id, memory, metadata = _compose_node(node)
- metadata["user_name"] = user_name
- metadata = self._prepare_node_metadata(metadata)
- metadata.update({"id": id, "memory": memory})
- properties = ", ".join(
- f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
- )
- node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
- self.execute_query(node_gql)
- except Exception as e:
- logger.error(f"Fail to load node: {node}, error: {e}")
-
- for edge in data.get("edges", []):
- try:
- source_id, target_id = edge["source"], edge["target"]
- edge_type = edge["type"]
- props = f'{{user_name: "{user_name}"}}'
- edge_gql = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
- '''
- self.execute_query(edge_gql)
- except Exception as e:
- logger.error(f"Fail to load edge: {edge}, error: {e}")
-
- @timed
- def get_all_memory_items(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> (list)[dict]:
- """
- Retrieve all memory items of a specific memory_type.
-
- Args:
- scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: Full list of memory items under this scope.
- """
- user_name = user_name if user_name else self.config.user_name
- if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
- raise ValueError(f"Unsupported memory type scope: {scope}")
-
- where_clause = f"WHERE n.memory_type = '{scope}'"
- where_clause += f" AND n.user_name = '{user_name}'"
-
- return_fields = self._build_return_fields(include_embedding)
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {return_fields}
- LIMIT 100
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(f"Failed to get memories: {e}")
- return nodes
-
- @timed
- def get_structure_optimization_candidates(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> list[dict]:
- """
- Find nodes that are likely candidates for structure optimization:
- - Isolated nodes, nodes with empty background, or nodes with exactly one child.
- - Plus: the child of any parent node that has exactly one child.
- """
- user_name = user_name if user_name else self.config.user_name
- where_clause = f'''
- n.memory_type = "{scope}"
- AND n.status = "activated"
- '''
- where_clause += f' AND n.user_name = "{user_name}"'
-
- return_fields = self._build_return_fields(include_embedding)
- return_fields += f", n.{self.dim_field} AS {self.dim_field}"
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
- OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
- WHERE c IS NULL AND p IS NULL
- RETURN {return_fields}
- """
-
- candidates = []
- node_ids = set()
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- node_id = node["id"]
- if node_id not in node_ids:
- candidates.append(node)
- node_ids.add(node_id)
- except Exception as e:
- logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
- return candidates
-
- @timed
- def drop_database(self) -> None:
- """
- Permanently delete the entire database this instance is using.
- WARNING: This operation is destructive and cannot be undone.
- """
- raise ValueError(
- f"Refusing to drop protected database: `{self.db_name}` in "
- f"Shared Database Multi-Tenant mode"
- )
-
- @timed
- def detect_conflicts(self) -> list[tuple[str, str]]:
- """
- Detect conflicting nodes based on logical or semantic inconsistency.
- Returns:
- A list of (node_id1, node_id2) tuples that conflict.
- """
- raise NotImplementedError
-
- @timed
- # Structure Maintenance
- def deduplicate_nodes(self) -> None:
- """
- Deduplicate redundant or semantically similar nodes.
- This typically involves identifying nodes with identical or near-identical memory.
- """
- raise NotImplementedError
-
- @timed
- def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
- """
- Get the ordered context chain starting from a node, following a relationship type.
- Args:
- id: Starting node ID.
- type: Relationship type to follow (e.g., 'FOLLOWS').
- Returns:
- List of ordered node IDs in the chain.
- """
- raise NotImplementedError
-
- @timed
- def get_neighbors(
- self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
- ) -> list[str]:
- """
- Get connected node IDs in a specific direction and relationship type.
- Args:
- id: Source node ID.
- type: Relationship type.
- direction: Edge direction to follow ('out', 'in', or 'both').
- Returns:
- List of neighboring node IDs.
- """
- raise NotImplementedError
-
- @timed
- def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
- """
- Get the path of nodes from source to target within a limited depth.
- Args:
- source_id: Starting node ID.
- target_id: Target node ID.
- max_depth: Maximum path length to traverse.
- Returns:
- Ordered list of node IDs along the path.
- """
- raise NotImplementedError
-
- @timed
- def merge_nodes(self, id1: str, id2: str) -> str:
- """
- Merge two similar or duplicate nodes into one.
- Args:
- id1: First node ID.
- id2: Second node ID.
- Returns:
- ID of the resulting merged node.
- """
- raise NotImplementedError
-
- @classmethod
- def _ensure_space_exists(cls, tmp_client, cfg):
- """Lightweight check to ensure target graph (space) exists."""
- db_name = getattr(cfg, "space", None)
- if not db_name:
- logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
- return
-
- try:
- res = tmp_client.execute("SHOW GRAPHS")
- existing = {row.values()[0].as_string() for row in res}
- if db_name not in existing:
- tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
- logger.info(f"โ
Graph `{db_name}` created before session binding.")
- else:
- logger.debug(f"Graph `{db_name}` already exists.")
- except Exception:
- logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
-
- @timed
- def _ensure_database_exists(self):
- graph_type_name = "MemOSBgeM3Type"
-
- check_type_query = "SHOW GRAPH TYPES"
- result = self.execute_query(check_type_query, auto_set_db=False)
-
- type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
-
- if not type_exists:
- create_tag = f"""
- CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
- NODE Memory (:MemoryTag {{
- id STRING,
- memory STRING,
- user_name STRING,
- user_id STRING,
- session_id STRING,
- status STRING,
- key STRING,
- confidence FLOAT,
- tags LIST,
- created_at STRING,
- updated_at STRING,
- memory_type STRING,
- sources LIST,
- source STRING,
- node_type STRING,
- visibility STRING,
- usage LIST,
- background STRING,
- {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
- PRIMARY KEY(id)
- }}),
- EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
- }}
- """
- self.execute_query(create_tag, auto_set_db=False)
- else:
- describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
- desc_result = self.execute_query(describe_query, auto_set_db=False)
-
- memory_fields = []
- for row in desc_result:
- field_name = row.values()[0].as_string()
- memory_fields.append(field_name)
-
- if self.dim_field not in memory_fields:
- alter_query = f"""
- ALTER GRAPH TYPE {graph_type_name} {{
- ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
- }}
- """
- self.execute_query(alter_query, auto_set_db=False)
- logger.info(f"โ
Add new vector search {self.dim_field} to {graph_type_name}")
- else:
- logger.info(f"โ
Graph Type {graph_type_name} already include {self.dim_field}")
-
- create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
- try:
- self.execute_query(create_graph, auto_set_db=False)
- logger.info(f"โ
Graph ``{self.db_name}`` is now the working graph.")
- except Exception as e:
- logger.error(f"โ Failed to create tag: {e} trace: {traceback.format_exc()}")
-
- @timed
- def _create_vector_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- """
- Create a vector index for the specified property in the label.
- """
- if str(dimensions) == str(self.default_memory_dimension):
- index_name = f"idx_{vector_property}"
- vector_name = vector_property
- else:
- index_name = f"idx_{vector_property}_{dimensions}"
- vector_name = f"{vector_property}_{dimensions}"
-
- create_vector_index = f"""
- CREATE VECTOR INDEX IF NOT EXISTS {index_name}
- ON NODE {label}::{vector_name}
- OPTIONS {{
- DIM: {dimensions},
- METRIC: IP,
- TYPE: IVF,
- NLIST: 100,
- TRAINSIZE: 1000
- }}
- FOR `{self.db_name}`
- """
- self.execute_query(create_vector_index)
- logger.info(
- f"โ
Ensure {label}::{vector_property} vector index {index_name} "
- f"exists (DIM={dimensions})"
- )
-
- @timed
- def _create_basic_property_indexes(self) -> None:
- """
- Create standard B-tree indexes on status, memory_type, created_at
- and updated_at fields.
- Create standard B-tree indexes on user_name when use Shared Database
- Multi-Tenant Mode.
- """
- fields = [
- "status",
- "memory_type",
- "created_at",
- "updated_at",
- "user_name",
- ]
-
- for field in fields:
- index_name = f"idx_memory_{field}"
- gql = f"""
- CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
- FOR `{self.db_name}`
- """
- try:
- self.execute_query(gql)
- logger.info(f"โ
Created index: {index_name} on field {field}")
- except Exception as e:
- logger.error(
- f"โ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
- )
-
- @timed
- def _index_exists(self, index_name: str) -> bool:
- """
- Check if an index with the given name exists.
- """
- """
- Check if a vector index with the given name exists in NebulaGraph.
-
- Args:
- index_name (str): The name of the index to check.
-
- Returns:
- bool: True if the index exists, False otherwise.
- """
- query = "SHOW VECTOR INDEXES"
- try:
- result = self.execute_query(query)
- return any(row.values()[0].as_string() == index_name for row in result)
- except Exception as e:
- logger.error(f"[Nebula] Failed to check index existence: {e}")
- return False
-
- @timed
- def _parse_value(self, value: Any) -> Any:
- """turn Nebula ValueWrapper to Python type"""
- from nebulagraph_python.value_wrapper import ValueWrapper
-
- if value is None or (hasattr(value, "is_null") and value.is_null()):
- return None
- try:
- prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
- except Exception as e:
- logger.warning(f"Error when decode Nebula ValueWrapper: {e}")
- prim = value.cast() if isinstance(value, ValueWrapper) else value
-
- if isinstance(prim, ValueWrapper):
- return self._parse_value(prim)
- if isinstance(prim, list):
- return [self._parse_value(v) for v in prim]
- if type(prim).__name__ == "NVector":
- return list(prim.values)
-
- return prim # already a Python primitive
-
- def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
- parsed = {k: self._parse_value(v) for k, v in props.items()}
-
- for tf in ("created_at", "updated_at"):
- if tf in parsed and parsed[tf] is not None:
- parsed[tf] = _normalize_datetime(parsed[tf])
-
- node_id = parsed.pop("id")
- memory = parsed.pop("memory", "")
- parsed.pop("user_name", None)
- metadata = parsed
- metadata["type"] = metadata.pop("node_type")
-
- if self.dim_field in metadata:
- metadata["embedding"] = metadata.pop(self.dim_field)
-
- return {"id": node_id, "memory": memory, "metadata": metadata}
-
- @timed
- def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Ensure metadata has proper datetime fields and normalized types.
-
- - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
- - Convert embedding to list of float if present.
- """
- now = datetime.utcnow().isoformat()
- metadata["node_type"] = metadata.pop("type")
-
- # Fill timestamps if missing
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
-
- # Normalize embedding type
- embedding = metadata.get("embedding")
- if embedding and isinstance(embedding, list):
- metadata.pop("embedding")
- metadata[self.dim_field] = _normalize([float(x) for x in embedding])
-
- return metadata
-
- @timed
- def _format_value(self, val: Any, key: str = "") -> str:
- from nebulagraph_python.py_data_types import NVector
-
- # None
- if val is None:
- return "NULL"
- # bool
- if isinstance(val, bool):
- return "true" if val else "false"
- # str
- if isinstance(val, str):
- return f'"{_escape_str(val)}"'
- # num
- elif isinstance(val, (int | float)):
- return str(val)
- # time
- elif isinstance(val, datetime):
- return f'datetime("{val.isoformat()}")'
- # list
- elif isinstance(val, list):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- return f"[{', '.join(self._format_value(v) for v in val)}]"
- # NVector
- elif isinstance(val, NVector):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- logger.warning("Invalid NVector")
- # dict
- if isinstance(val, dict):
- j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
- return f'"{_escape_str(j)}"'
- else:
- return f'"{_escape_str(str(val))}"'
-
- @timed
- def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Filter and validate metadata dictionary against the Memory node schema.
- - Removes keys not in schema.
- - Warns if required fields are missing.
- """
-
- dim_fields = {self.dim_field}
-
- allowed_fields = self.common_fields | dim_fields
-
- missing_fields = allowed_fields - metadata.keys()
- if missing_fields:
- logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
-
- filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
- return filtered_metadata
-
- def _build_return_fields(self, include_embedding: bool = False) -> str:
- fields = set(self.base_fields)
- if include_embedding:
- fields.add(self.dim_field)
- return ", ".join(f"n.{f} AS {f}" for f in fields)
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 4d88844df..f61d981d8 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -441,7 +441,7 @@ def remove_oldest_memory(
)
user_name = user_name if user_name else self._get_config_value("user_name")
- # Use actual OFFSET logic, consistent with nebular.py
+ # Use actual OFFSET logic for deterministic pruning
# First find IDs to delete, then delete them
select_query = f"""
SELECT id FROM "{self.db_name}_graph"."Memory"
@@ -3539,7 +3539,7 @@ def get_neighbors_by_tag_ccl(
user_name = user_name if user_name else self._get_config_value("user_name")
- # Build query conditions; keep consistent with nebular.py
+ # Build query conditions shared with other graph backends
where_clauses = [
'n.status = "activated"',
'NOT (n.node_type = "reasoning")',
@@ -3592,7 +3592,7 @@ def get_neighbors_by_tag_ccl(
# Add overlap_count
result_fields.append("overlap_count agtype")
result_fields_str = ", ".join(result_fields)
- # Use Cypher query; keep consistent with nebular.py
+ # Use Cypher query to keep the graph query path aligned
query = f"""
SELECT * FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index ec431c253..671190e6f 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,