Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/neo4j_graphrag/retrievers/external/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,6 +117,7 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label=node_label,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -138,6 +140,7 @@ def __init__(
self.return_properties = validated_data.return_properties
self.retrieval_query = validated_data.retrieval_query
self.result_formatter = validated_data.result_formatter
self.node_label = validated_data.node_label

def get_search_results(
self,
Expand Down Expand Up @@ -223,6 +226,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label,
)

parameters = {
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/pinecone/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ class PineconeNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label: Optional[str] = None
4 changes: 4 additions & 0 deletions src/neo4j_graphrag/retrievers/external/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,6 +117,7 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label=node_label,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -138,6 +140,7 @@ def __init__(
self.return_properties = validated_data.return_properties
self.retrieval_query = validated_data.retrieval_query
self.result_formatter = validated_data.result_formatter
self.node_label = validated_data.node_label

def get_search_results(
self,
Expand Down Expand Up @@ -223,6 +226,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label,
)

parameters = {
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/qdrant/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,4 @@ class QdrantNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label: Optional[str] = None
11 changes: 8 additions & 3 deletions src/neo4j_graphrag/retrievers/external/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@


def get_match_query(
return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None
return_properties: Optional[list[str]] = None,
retrieval_query: Optional[str] = None,
node_label: Optional[str] = None,
) -> str:
match_query = (
"UNWIND $match_params AS match_param "
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
"MATCH (node) "
"WHERE node[$id_property] = match_id_value "
)
if node_label:
match_query += f"MATCH (node:{node_label}) "
else:
match_query += "MATCH (node) "
match_query += "WHERE node[$id_property] = match_id_value "
return match_query + get_query_tail(
return_properties=return_properties,
retrieval_query=retrieval_query,
Expand Down
1 change: 1 addition & 0 deletions src/neo4j_graphrag/retrievers/external/weaviate/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
retrieval_query: Optional[str] = None
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
neo4j_database: Optional[str] = None
node_label: Optional[str] = None


class WeaviateNeo4jSearchModel(VectorSearchModel):
Expand Down
4 changes: 4 additions & 0 deletions src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
Callable[[neo4j.Record], RetrieverResultItem]
] = None,
neo4j_database: Optional[str] = None,
node_label: Optional[str] = None,
):
try:
driver_model = Neo4jDriverModel(driver=driver)
Expand All @@ -116,6 +117,7 @@ def __init__(
retrieval_query=retrieval_query,
result_formatter=result_formatter,
neo4j_database=neo4j_database,
node_label=node_label,
)
except ValidationError as e:
raise RetrieverInitializationError(e.errors()) from e
Expand All @@ -134,6 +136,7 @@ def __init__(
self.return_properties = validated_data.return_properties
self.retrieval_query = validated_data.retrieval_query
self.result_formatter = validated_data.result_formatter
self.node_label = validated_data.node_label

def get_search_results(
self,
Expand Down Expand Up @@ -234,6 +237,7 @@ def get_search_results(
search_query = get_match_query(
return_properties=self.return_properties,
retrieval_query=self.retrieval_query,
node_label=self.node_label,
)

parameters = {
Expand Down
Loading