From 0c5be2f7baeadc754884df401a109c2ae8e46a85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Fri, 17 Oct 2025 16:43:49 +0800 Subject: [PATCH] fix: nebula efficiency --- src/memos/graph_dbs/nebular.py | 264 ++++++++++++++++----------------- 1 file changed, 131 insertions(+), 133 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..150d84ccf 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -169,7 +169,7 @@ def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> " if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension) else "embedding" ) - tmp.system_db_name = "system" if getattr(cfg, "use_multi_db", False) else cfg.space + tmp.system_db_name = cfg.space tmp._client = client tmp._owns_client = False return tmp @@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N client = cls._CLIENT_CACHE.get(key) if client is None: # Connection setting + + tmp_client = NebulaClient( + hosts=cfg.uri, + username=cfg.user, + password=cfg.password, + session_config=SessionConfig(graph=None), + session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), + ) + try: + cls._ensure_space_exists(tmp_client, cfg) + finally: + tmp_client.close() + conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) if conn_conf is None: conn_conf = ConnectionConfig.from_defults( @@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig): } """ + assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" self.config = config self.db_name = config.space self.user_name = config.user_name @@ -350,7 +364,7 @@ def __init__(self, config: NebulaGraphDBConfig): if (str(self.embedding_dimension) != str(self.default_memory_dimension)) else "embedding" ) - self.system_db_name = "system" if config.use_multi_db else config.space + self.system_db_name = config.space # ---- NEW: pool acquisition strategy # Get or create a shared pool from the class-level cache @@ -425,27 +439,29 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. """ - optional_condition = "" - if not self.config.use_multi_db and self.config.user_name: - optional_condition = f"AND n.user_name = '{self.config.user_name}'" + optional_condition = f"AND n.user_name = '{self.config.user_name}'" - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {keep_latest} - DETACH DELETE n - """ - self.execute_query(query) + try: + count = self.count_nodes(memory_type) + if count > keep_latest: + delete_query = f""" + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) + WHERE n.memory_type = '{memory_type}' + {optional_condition} + ORDER BY n.updated_at DESC + OFFSET {keep_latest} + DETACH DELETE n + """ + self.execute_query(delete_query) + except Exception as e: + logger.warning(f"Delete old mem error: {e}") @timed def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: """ Insert or update a Memory node in NebulaGraph. """ - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + metadata["user_name"] = self.config.user_name now = datetime.utcnow() metadata = metadata.copy() @@ -476,12 +492,9 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @timed def node_not_exist(self, scope: str) -> int: - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' - else: - filter_clause = f'n.memory_type = "{scope}"' + filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{self.config.user_name}"' query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {filter_clause} RETURN n.id AS id LIMIT 1 @@ -510,8 +523,7 @@ def update_node(self, id: str, fields: dict[str, Any]) -> None: MATCH (n@Memory {{id: "{id}"}}) """ - if not self.config.use_multi_db and self.config.user_name: - query += f'WHERE n.user_name = "{self.config.user_name}"' + query += f'WHERE n.user_name = "{self.config.user_name}"' query += f"\nSET {set_clause_str}" self.execute_query(query) @@ -526,9 +538,8 @@ def delete_node(self, id: str) -> None: query = f""" MATCH (n@Memory {{id: "{id}"}}) """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" WHERE n.user_name = {self._format_value(user_name)}" + user_name = self.config.user_name + query += f" WHERE n.user_name = {self._format_value(user_name)}" query += "\n DETACH DELETE n" self.execute_query(query) @@ -544,9 +555,7 @@ def add_edge(self, source_id: str, target_id: str, type: str): if not source_id or not target_id: raise ValueError("[add_edge] source_id and target_id must be provided") - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{self.config.user_name}"}}' insert_stmt = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) @@ -571,9 +580,8 @@ def delete_edge(self, source_id: str, target_id: str, type: str) -> None: WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)} """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" + user_name = self.config.user_name + query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}" query += "\nDELETE r" self.execute_query(query) @@ -584,9 +592,8 @@ def get_memory_count(self, memory_type: str) -> int: MATCH (n@Memory) WHERE n.memory_type = "{memory_type}" """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + user_name = self.config.user_name + query += f"\nAND n.user_name = '{user_name}'" query += "\nRETURN COUNT(n) AS count" try: @@ -597,14 +604,18 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + def count_nodes(self, scope: str | None = None) -> int: + query = "MATCH (n@Memory)" + conditions = [] + + if scope: + conditions.append(f'n.memory_type = "{scope}"') + user_name = self.config.user_name + conditions.append(f"n.user_name = '{user_name}'") + + if conditions: + query += "\nWHERE " + " AND ".join(conditions) + query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -640,9 +651,8 @@ def edge_exists( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." ) query = f"MATCH {pattern}" - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" + user_name = self.config.user_name + query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'" query += "\nRETURN r" # Run the Cypher query @@ -665,10 +675,7 @@ def get_node(self, id: str, include_embedding: bool = False) -> dict[str, Any] | Returns: dict: Node properties as key-value pairs, or None if not found. """ - if not self.config.use_multi_db and self.config.user_name: - filter_clause = f'n.user_name = "{self.config.user_name}" AND n.id = "{id}"' - else: - filter_clause = f'n.id = "{id}"' + filter_clause = f'n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" @@ -709,19 +716,17 @@ def get_nodes( if not ids: return [] - where_user = "" - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_user = f" AND n.user_name = '{kwargs['cube_name']}'" - else: - where_user = f" AND n.user_name = '{self.config.user_name}'" + if kwargs.get("cube_name"): + where_user = f" AND n.user_name = '{kwargs['cube_name']}'" + else: + where_user = f" AND n.user_name = '{self.config.user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE n.id IN [{id_list}] {where_user} RETURN {return_fields} """ @@ -770,8 +775,7 @@ def get_edges(self, id: str, type: str = "ANY", direction: str = "ANY") -> list[ else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" + where_clause += f" AND a.user_name = '{self.config.user_name}' AND b.user_name = '{self.config.user_name}'" query = f""" MATCH {pattern} @@ -824,8 +828,7 @@ def get_neighbors_by_tag( if exclude_ids: where_clauses.append(f"NOT (n.id IN {exclude_ids})") - if not self.config.use_multi_db and self.config.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{self.config.user_name}"') where_clause = " AND ".join(where_clauses) tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]" @@ -834,7 +837,7 @@ def get_neighbors_by_tag( query = f""" LET tag_list = {tag_list_literal} - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_clause} RETURN {return_fields}, size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count @@ -860,11 +863,8 @@ def get_neighbors_by_tag( @timed def get_children_with_embeddings(self, id: str) -> list[dict[str, Any]]: - where_user = "" - - if not self.config.use_multi_db and self.config.user_name: - user_name = self.config.user_name - where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" + user_name = self.config.user_name + where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'" query = f""" MATCH (p@Memory)-[@PARENT]->(c@Memory) @@ -904,7 +904,7 @@ def get_subgraph( user_name = self.config.user_name gql = f""" - MATCH (center@Memory) + MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */) WHERE center.id = '{center_id}' AND center.status = '{center_status}' AND center.user_name = '{user_name}' @@ -985,38 +985,33 @@ def search_by_embedding( dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: where_clauses.append(f'n.status = "{status}"') - if not self.config.use_multi_db and self.config.user_name: - if kwargs.get("cube_name"): - where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') - else: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + if kwargs.get("cube_name"): + where_clauses.append(f'n.user_name = "{kwargs["cube_name"]}"') + else: + where_clauses.append(f'n.user_name = "{self.config.user_name}"') - # Add search_filter conditions - if search_filter: - for key, value in search_filter.items(): - if isinstance(value, str): - where_clauses.append(f'n.{key} = "{value}"') - else: - where_clauses.append(f"n.{key} = {value}") + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append(f'n.{key} = "{value}"') + else: + where_clauses.append(f"n.{key} = {value}") where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE + ORDER BY inner_product(n.{self.dim_field}, a) DESC LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1087,11 +1082,10 @@ def get_by_metadata(self, filters: list[dict[str, Any]]) -> list[str]: else: raise ValueError(f"Unsupported operator: {op}") - if not self.config.use_multi_db and self.user_name: - where_clauses.append(f'n.user_name = "{self.config.user_name}"') + where_clauses.append(f'n.user_name = "{self.config.user_name}"') where_str = " AND ".join(where_clauses) - gql = f"MATCH (n@Memory) WHERE {where_str} RETURN n.id AS id" + gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id" ids = [] try: result = self.execute_query(gql) @@ -1123,16 +1117,15 @@ def get_grouped_counts( raise ValueError("group_fields cannot be empty") # GQL-specific modifications - if not self.config.use_multi_db and self.config.user_name: - user_clause = f"n.user_name = '{self.config.user_name}'" - if where_clause: - where_clause = where_clause.strip() - if where_clause.upper().startswith("WHERE"): - where_clause += f" AND {user_clause}" - else: - where_clause = f"WHERE {where_clause} AND {user_clause}" + user_clause = f"n.user_name = '{self.config.user_name}'" + if where_clause: + where_clause = where_clause.strip() + if where_clause.upper().startswith("WHERE"): + where_clause += f" AND {user_clause}" else: - where_clause = f"WHERE {user_clause}" + where_clause = f"WHERE {where_clause} AND {user_clause}" + else: + where_clause = f"WHERE {user_clause}" # Inline parameters if provided if params: @@ -1151,7 +1144,7 @@ def get_grouped_counts( group_by_fields.append(alias) # Full GQL query construction gql = f""" - MATCH (n) + MATCH (n /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {", ".join(return_fields)}, COUNT(n) AS count GROUP BY {", ".join(group_by_fields)} @@ -1175,10 +1168,9 @@ def clear(self) -> None: Clear the entire graph if the target database exists. """ try: - if not self.config.use_multi_db and self.config.user_name: - query = f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" - else: - query = "MATCH (n) DETACH DELETE n" + query = ( + f"MATCH (n@Memory) WHERE n.user_name = '{self.config.user_name}' DETACH DELETE n" + ) self.execute_query(query) logger.info("Cleared all nodes from database.") @@ -1202,10 +1194,9 @@ def export_graph(self, include_embedding: bool = False) -> dict[str, Any]: node_query = "MATCH (n@Memory)" edge_query = "MATCH (a@Memory)-[r]->(b@Memory)" - if not self.config.use_multi_db and self.config.user_name: - username = self.config.user_name - node_query += f' WHERE n.user_name = "{username}"' - edge_query += f' WHERE r.user_name = "{username}"' + username = self.config.user_name + node_query += f' WHERE n.user_name = "{username}"' + edge_query += f' WHERE r.user_name = "{username}"' try: if include_embedding: @@ -1276,8 +1267,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: id, memory, metadata = _compose_node(node) - if not self.config.use_multi_db and self.config.user_name: - metadata["user_name"] = self.config.user_name + metadata["user_name"] = self.config.user_name metadata = self._prepare_node_metadata(metadata) metadata.update({"id": id, "memory": memory}) @@ -1293,9 +1283,7 @@ def import_graph(self, data: dict[str, Any]) -> None: try: source_id, target_id = edge["source"], edge["target"] edge_type = edge["type"] - props = "" - if not self.config.use_multi_db and self.config.user_name: - props = f'{{user_name: "{self.config.user_name}"}}' + props = f'{{user_name: "{self.config.user_name}"}}' edge_gql = f''' MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}}) INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b) @@ -1320,14 +1308,12 @@ def get_all_memory_items(self, scope: str, include_embedding: bool = False) -> ( raise ValueError(f"Unsupported memory type scope: {scope}") where_clause = f"WHERE n.memory_type = '{scope}'" - - if not self.config.use_multi_db and self.config.user_name: - where_clause += f" AND n.user_name = '{self.config.user_name}'" + where_clause += f" AND n.user_name = '{self.config.user_name}'" return_fields = self._build_return_fields(include_embedding) query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} RETURN {return_fields} LIMIT 100 @@ -1356,14 +1342,13 @@ def get_structure_optimization_candidates( n.memory_type = "{scope}" AND n.status = "activated" ''' - if not self.config.use_multi_db and self.config.user_name: - where_clause += f' AND n.user_name = "{self.config.user_name}"' + where_clause += f' AND n.user_name = "{self.config.user_name}"' return_fields = self._build_return_fields(include_embedding) return_fields += f", n.{self.dim_field} AS {self.dim_field}" query = f""" - MATCH (n@Memory) + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_clause} OPTIONAL MATCH (n)-[@PARENT]->(c@Memory) OPTIONAL MATCH (p@Memory)-[@PARENT]->(n) @@ -1392,14 +1377,10 @@ def drop_database(self) -> None: Permanently delete the entire database this instance is using. WARNING: This operation is destructive and cannot be undone. """ - if self.config.use_multi_db: - self.execute_query(f"DROP GRAPH `{self.db_name}`") - logger.info(f"Database '`{self.db_name}`' has been dropped.") - else: - raise ValueError( - f"Refusing to drop protected database: `{self.db_name}` in " - f"Shared Database Multi-Tenant mode" - ) + raise ValueError( + f"Refusing to drop protected database: `{self.db_name}` in " + f"Shared Database Multi-Tenant mode" + ) @timed def detect_conflicts(self) -> list[tuple[str, str]]: @@ -1471,6 +1452,25 @@ def merge_nodes(self, id1: str, id2: str) -> str: """ raise NotImplementedError + @classmethod + def _ensure_space_exists(cls, tmp_client, cfg): + """Lightweight check to ensure target graph (space) exists.""" + db_name = getattr(cfg, "space", None) + if not db_name: + logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.") + return + + try: + res = tmp_client.execute("SHOW GRAPHS;") + existing = {row.values()[0].as_string() for row in res} + if db_name not in existing: + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + logger.info(f"✅ Graph `{db_name}` created before session binding.") + else: + logger.debug(f"Graph `{db_name}` already exists.") + except Exception: + logger.exception("[NebulaGraphDBSync] Failed to ensure space exists") + @timed def _ensure_database_exists(self): graph_type_name = "MemOSBgeM3Type" @@ -1585,9 +1585,7 @@ def _create_basic_property_indexes(self) -> None: Create standard B-tree indexes on user_name when use Shared Database Multi-Tenant Mode. """ - fields = ["status", "memory_type", "created_at", "updated_at"] - if not self.config.use_multi_db: - fields.append("user_name") + fields = ["status", "memory_type", "created_at", "updated_at", "user_name"] for field in fields: index_name = f"idx_memory_{field}"