From f37ccc522cfd5675611a9ab2cac3aac1e708e3ac Mon Sep 17 00:00:00 2001 From: Gurveer Singh Virk <110109294+gurveervirk@users.noreply.github.com> Date: Wed, 10 Sep 2025 11:28:47 +0000 Subject: [PATCH] modified retrievers to support optional node labels --- .../retrievers/external/pinecone/pinecone.py | 4 ++++ .../retrievers/external/pinecone/types.py | 1 + .../retrievers/external/qdrant/qdrant.py | 4 ++++ .../retrievers/external/qdrant/types.py | 1 + src/neo4j_graphrag/retrievers/external/utils.py | 11 ++++++++--- .../retrievers/external/weaviate/types.py | 1 + .../retrievers/external/weaviate/weaviate.py | 4 ++++ 7 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py index 16b3ea3d6..44d111f21 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py @@ -101,6 +101,7 @@ def __init__( Callable[[neo4j.Record], RetrieverResultItem] ] = None, neo4j_database: Optional[str] = None, + node_label: Optional[str] = None, ): try: driver_model = Neo4jDriverModel(driver=driver) @@ -116,6 +117,7 @@ def __init__( retrieval_query=retrieval_query, result_formatter=result_formatter, neo4j_database=neo4j_database, + node_label=node_label, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -138,6 +140,7 @@ def __init__( self.return_properties = validated_data.return_properties self.retrieval_query = validated_data.retrieval_query self.result_formatter = validated_data.result_formatter + self.node_label = validated_data.node_label def get_search_results( self, @@ -223,6 +226,7 @@ def get_search_results( search_query = get_match_query( return_properties=self.return_properties, retrieval_query=self.retrieval_query, + node_label=self.node_label, ) parameters = { diff --git a/src/neo4j_graphrag/retrievers/external/pinecone/types.py b/src/neo4j_graphrag/retrievers/external/pinecone/types.py index 397894fd5..ae69abe45 100644 --- a/src/neo4j_graphrag/retrievers/external/pinecone/types.py +++ b/src/neo4j_graphrag/retrievers/external/pinecone/types.py @@ -59,3 +59,4 @@ class PineconeNeo4jRetrieverModel(BaseModel): retrieval_query: Optional[str] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None + node_label: Optional[str] = None diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py index a5e47f2a3..fbdd5e6c1 100644 --- a/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py +++ b/src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py @@ -99,6 +99,7 @@ def __init__( Callable[[neo4j.Record], RetrieverResultItem] ] = None, neo4j_database: Optional[str] = None, + node_label: Optional[str] = None, ): try: driver_model = Neo4jDriverModel(driver=driver) @@ -116,6 +117,7 @@ def __init__( retrieval_query=retrieval_query, result_formatter=result_formatter, neo4j_database=neo4j_database, + node_label=node_label, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -138,6 +140,7 @@ def __init__( self.return_properties = validated_data.return_properties self.retrieval_query = validated_data.retrieval_query self.result_formatter = validated_data.result_formatter + self.node_label = validated_data.node_label def get_search_results( self, @@ -223,6 +226,7 @@ def get_search_results( search_query = get_match_query( return_properties=self.return_properties, retrieval_query=self.retrieval_query, + node_label=self.node_label, ) parameters = { diff --git a/src/neo4j_graphrag/retrievers/external/qdrant/types.py b/src/neo4j_graphrag/retrievers/external/qdrant/types.py index d0240c150..57e7f7269 100644 --- a/src/neo4j_graphrag/retrievers/external/qdrant/types.py +++ b/src/neo4j_graphrag/retrievers/external/qdrant/types.py @@ -54,3 +54,4 @@ class QdrantNeo4jRetrieverModel(BaseModel): retrieval_query: Optional[str] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None + node_label: Optional[str] = None diff --git a/src/neo4j_graphrag/retrievers/external/utils.py b/src/neo4j_graphrag/retrievers/external/utils.py index a0a97b448..df0320c57 100644 --- a/src/neo4j_graphrag/retrievers/external/utils.py +++ b/src/neo4j_graphrag/retrievers/external/utils.py @@ -20,14 +20,19 @@ def get_match_query( - return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None + return_properties: Optional[list[str]] = None, + retrieval_query: Optional[str] = None, + node_label: Optional[str] = None, ) -> str: match_query = ( "UNWIND $match_params AS match_param " "WITH match_param[0] AS match_id_value, match_param[1] AS score " - "MATCH (node) " - "WHERE node[$id_property] = match_id_value " ) + if node_label: + match_query += f"MATCH (node:{node_label}) " + else: + match_query += "MATCH (node) " + match_query += "WHERE node[$id_property] = match_id_value " return match_query + get_query_tail( return_properties=return_properties, retrieval_query=retrieval_query, diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/types.py b/src/neo4j_graphrag/retrievers/external/weaviate/types.py index 2eaa0e5c7..3c1f1cd4a 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/types.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/types.py @@ -57,6 +57,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel): retrieval_query: Optional[str] = None result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None neo4j_database: Optional[str] = None + node_label: Optional[str] = None class WeaviateNeo4jSearchModel(VectorSearchModel): diff --git a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py index cad21443e..4fc25a04a 100644 --- a/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py @@ -100,6 +100,7 @@ def __init__( Callable[[neo4j.Record], RetrieverResultItem] ] = None, neo4j_database: Optional[str] = None, + node_label: Optional[str] = None, ): try: driver_model = Neo4jDriverModel(driver=driver) @@ -116,6 +117,7 @@ def __init__( retrieval_query=retrieval_query, result_formatter=result_formatter, neo4j_database=neo4j_database, + node_label=node_label, ) except ValidationError as e: raise RetrieverInitializationError(e.errors()) from e @@ -134,6 +136,7 @@ def __init__( self.return_properties = validated_data.return_properties self.retrieval_query = validated_data.retrieval_query self.result_formatter = validated_data.result_formatter + self.node_label = validated_data.node_label def get_search_results( self, @@ -234,6 +237,7 @@ def get_search_results( search_query = get_match_query( return_properties=self.return_properties, retrieval_query=self.retrieval_query, + node_label=self.node_label, ) parameters = {