diff --git a/memgraph-toolbox/README.md b/memgraph-toolbox/README.md index 3602819..842bfff 100644 --- a/memgraph-toolbox/README.md +++ b/memgraph-toolbox/README.md @@ -19,6 +19,8 @@ Below is a list of tools included in the toolbox, along with their descriptions: 7. `CypherTool` - Executes arbitrary [Cypher queries](https://memgraph.com/docs/querying) on a Memgraph database. 8. `ShowConstraintInfoTool` - Shows [constraint](https://memgraph.com/docs/fundamentals/constraints) information from a Memgraph database. 9. `ShowConfigTool` - Shows [configuration](https://memgraph.com/docs/database-management/configuration) information from a Memgraph database. +10. `NodeVectorSearchTool` - Searches the most similar nodes using the Memgraph's [vector search](https://memgraph.com/docs/querying/vector-search). +11. `NodeNeighborhoodTool` - Searches for the data attached to a given node using Memgraph's [deep-path traversals](https://memgraph.com/docs/advanced-algorithms/deep-path-traversal). ## Usage diff --git a/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py b/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py index 6602a97..0c0504f 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py +++ b/memgraph-toolbox/src/memgraph_toolbox/memgraph_toolbox.py @@ -7,6 +7,7 @@ from .tools.constraint import ShowConstraintInfoTool from .tools.cypher import CypherTool from .tools.index import ShowIndexInfoTool +from .tools.node_neighborhood import NodeNeighborhoodTool from .tools.node_vector_search import NodeVectorSearchTool from .tools.page_rank import PageRankTool from .tools.schema import ShowSchemaInfoTool @@ -37,6 +38,7 @@ def __init__(self, db: Memgraph): self.add_tool(ShowConstraintInfoTool(db)) self.add_tool(CypherTool(db)) self.add_tool(ShowIndexInfoTool(db)) + self.add_tool(NodeNeighborhoodTool(db)) self.add_tool(NodeVectorSearchTool(db)) self.add_tool(PageRankTool(db)) self.add_tool(ShowSchemaInfoTool(db)) diff --git a/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py b/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py index d9715dc..0c42513 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tests/test_toolbox.py @@ -51,7 +51,7 @@ def test_memgraph_toolbox(): tools = toolkit.get_all_tools() # Check if we have all 9 tools - assert len(tools) == 10 + assert len(tools) == 11 # Check for specific tool names tool_names = [tool.name for tool in tools] @@ -66,6 +66,7 @@ def test_memgraph_toolbox(): "show_schema_info", "show_storage_info", "show_triggers", + "node_neighborhood", ] for expected_tool in expected_tools: diff --git a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py index 5a5930b..86453c8 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tests/test_tools.py @@ -6,6 +6,7 @@ from ..tools.constraint import ShowConstraintInfoTool from ..tools.cypher import CypherTool from ..tools.index import ShowIndexInfoTool +from ..tools.node_neighborhood import NodeNeighborhoodTool from ..tools.node_vector_search import NodeVectorSearchTool from ..tools.page_rank import PageRankTool from ..tools.schema import ShowSchemaInfoTool @@ -282,3 +283,31 @@ def test_node_vector_search_tool(): 'MATCH (n:Person) WHERE "embedding" IN keys(n) DETACH DELETE n' ) memgraph_client.query("DROP VECTOR INDEX my_index") + + +def test_node_neighborhood_tool(): + """Test the NodeNeighborhood tool.""" + url = "bolt://localhost:7687" + user = "" + password = "" + memgraph_client = Memgraph(url=url, username=user, password=password) + + label = "TestNodeNeighborhoodToolLabel" + memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;") + memgraph_client.query( + f"CREATE (p1:{label} {{id: 1}})-[:KNOWS]->(p2:{label} {{id: 2}}), (p2)-[:KNOWS]->(p3:{label} {{id: 3}});" + ) + memgraph_client.query( + f"CREATE (p4:{label} {{id: 4}})-[:KNOWS]->(p5:{label} {{id: 5}});" + ) + ids = memgraph_client.query( + f"MATCH (p1:{label} {{id:1}}) RETURN id(p1) AS node_id;" + ) + assert len(ids) == 1 + node_id = ids[0]["node_id"] + + node_neighborhood_tool = NodeNeighborhoodTool(db=memgraph_client) + result = node_neighborhood_tool.call({"node_id": node_id, "max_distance": 2}) + assert isinstance(result, list) + assert len(result) == 2 + memgraph_client.query(f"MATCH (n:{label}) DETACH DELETE n;") diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py b/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py index e69de29..8b13789 100644 --- a/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py new file mode 100644 index 0000000..ced337e --- /dev/null +++ b/memgraph-toolbox/src/memgraph_toolbox/tools/node_neighborhood.py @@ -0,0 +1,59 @@ +from typing import Any, Dict, List + +from ..api.memgraph import Memgraph +from ..api.tool import BaseTool + + +class NodeNeighborhoodTool(BaseTool): + """ + Tool for finding nodes within a specified neighborhood distance in Memgraph. + """ + + def __init__(self, db: Memgraph): + super().__init__( + name="node_neighborhood", + description=( + "Finds nodes within a specified distance from a given node. " + "This tool explores the graph neighborhood around a starting node, " + "returning all nodes and relationships found within the specified radius." + ), + input_schema={ + "type": "object", + "properties": { + "node_id": { + "type": "string", + "description": "The ID of the starting node to find neighborhood around", + }, + "max_distance": { + "type": "integer", + "description": "Maximum distance (hops) to search from the starting node. Default is 1.", + "default": 1, + }, + "limit": { + "type": "integer", + "description": "Maximum number of nodes to return. Default is 100.", + "default": 100, + }, + }, + "required": ["node_id"], + }, + ) + self.db = db + + def call(self, arguments: Dict[str, Any]) -> List[Dict[str, Any]]: + """Execute the neighborhood search and return the results.""" + node_id = arguments["node_id"] + max_distance = arguments.get("max_distance", 1) + limit = arguments.get("limit", 100) + + query = f"""MATCH (n)-[r*..{max_distance}]-(m) WHERE id(n) = {node_id} RETURN DISTINCT m LIMIT {limit};""" + try: + results = self.db.query(query, {}) + processed_results = [] + for record in results: + node_data = record["m"]; + properties = {k: v for k, v in node_data.items()} + processed_results.append(properties) + return processed_results + except Exception as e: + return [{"error": f"Failed to find neighborhood: {str(e)}"}]