diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a4b5e897..dcf7e1e6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ - Support for conversations with message history, including a new `message_history` parameter for LLM interactions. - Ability to include system instructions in LLM invoke method. - Summarization of chat history to enhance query embedding and context handling in GraphRAG. +- Support for thresholding on vector and fulltext indexes in Hybrid retrievers, enabling users to set a minimum normalized score that results from each index must reach to be included. ### Changed - Updated LLM implementations to handle message history consistently across providers. diff --git a/examples/retrieve/hybrid_retriever.py b/examples/retrieve/hybrid_retriever.py index 51e5e6686..9c5d766ed 100644 --- a/examples/retrieve/hybrid_retriever.py +++ b/examples/retrieve/hybrid_retriever.py @@ -37,7 +37,14 @@ # Perform the similarity search for a text query # (retrieve the top 5 most similar nodes) query_text = "Find me a movie about aliens" - print(retriever.search(query_text=query_text, top_k=5)) + results = retriever.search( + query_text=query_text, + top_k=5, + threshold_vector_index=0.1, + threshold_fulltext_index=0.8, + ) + + print(results.items[0].metadata) # note: it is also possible to query from a query_vector directly: # query_vector: list[float] = [...] 
def _get_hybrid_query(
    neo4j_version_is_5_23_or_above: bool,
    threshold_vector_index: float = 0.0,
    threshold_fulltext_index: float = 0.0,
) -> str:
    """Build the Cypher fragment that merges vector and fulltext index hits.

    Each index is queried inside one ``CALL`` subquery. Its scores are divided
    by the best score of that index, compared with the ``$threshold_*`` query
    parameter (scores below it are replaced by 0), and the two result sets are
    combined with ``UNION``. The outer part keeps the best score per node and
    returns the ``$top_k`` highest-scoring nodes.

    The returned text references ``$threshold_vector_index``,
    ``$threshold_fulltext_index`` and ``$top_k`` (plus whatever the two index
    query constants reference); these must be supplied as query parameters
    when the query is executed.

    Args:
        neo4j_version_is_5_23_or_above (bool): Selects the ``CALL () {`` form
            of the subquery instead of the bare ``CALL {`` form.
        threshold_vector_index (float): When greater than 0, rows whose
            thresholded vector score is 0 are filtered out of the vector
            branch. The comparison itself uses the query parameter of the
            same name, not this value.
        threshold_fulltext_index (float): Same as above, for the fulltext
            branch.

    Returns:
        str: The hybrid search query, without a final ``RETURN`` clause.
    """
    # The leading space belongs to the clause itself so that it can never fuse
    # with the preceding "AS score" token, whichever CALL form is used.
    vector_where_clause = " WHERE score > 0" if threshold_vector_index > 0 else ""
    fulltext_where_clause = (
        " WHERE score > 0" if threshold_fulltext_index > 0 else ""
    )

    # Both server generations share one template; only the subquery opener
    # differs, which keeps the two variants from drifting apart.
    call_prefix = "CALL ()" if neo4j_version_is_5_23_or_above else "CALL"

    return f"""{call_prefix} {{
    {VECTOR_INDEX_QUERY}
    WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score
    UNWIND nodes AS n
    WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
    THEN (n.score / vector_index_max_score) ELSE 0 END AS score{vector_where_clause}
    RETURN node, score
    UNION
    {FULL_TEXT_SEARCH_QUERY}
    WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score
    UNWIND nodes AS n
    WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
    THEN (n.score / ft_index_max_score) ELSE 0 END AS score{fulltext_where_clause}
    RETURN node, score
}}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k"""
@@ -197,10 +213,8 @@ def get_search_query( embedding_node_property (str): the name of the property holding the embeddings embedding_dimension (int): the dimension of the embeddings filters (dict[str, Any]): filters used to pre-filter the nodes before vector search - Returns: tuple[str, dict[str, Any]]: query and parameters - """ warnings.warn( "The default returned 'id' field in the search results will be removed. Please switch to using 'elementId' instead.", diff --git a/src/neo4j_graphrag/retrievers/hybrid.py b/src/neo4j_graphrag/retrievers/hybrid.py index 4634b8a06..0bb54c449 100644 --- a/src/neo4j_graphrag/retrievers/hybrid.py +++ b/src/neo4j_graphrag/retrievers/hybrid.py @@ -48,6 +48,12 @@ class HybridRetriever(Retriever): """ Provides retrieval method using combination of vector search over embeddings and fulltext search. + + The retriever uses both the vector and fulltext indexes. It uses the user query to + search both indexes, retrieving nodes and their corresponding scores. + After normalizing the scores from each set of results, it merges them, ranks the + combined results by score, and returns the top matches. + If an embedder is provided, it needs to have the required Embedder type. Example: @@ -141,6 +147,8 @@ def get_search_results( query_text: str, query_vector: Optional[list[float]] = None, top_k: int = 5, + threshold_vector_index: float = 0.0, + threshold_fulltext_index: float = 0.0, ) -> RawSearchResult: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. @@ -159,6 +167,8 @@ def get_search_results( query_text (str): The text to get the closest neighbors of. query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None. top_k (int, optional): The number of neighbors to return. Defaults to 5. 
+ threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search. + threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search. Raises: SearchValidationError: If validation of the input arguments fail. @@ -180,6 +190,9 @@ def get_search_results( parameters["vector_index_name"] = self.vector_index_name parameters["fulltext_index_name"] = self.fulltext_index_name + parameters["threshold_vector_index"] = threshold_vector_index + parameters["threshold_fulltext_index"] = threshold_fulltext_index + if query_text and not query_vector: if not self.embedder: raise EmbeddingRequiredError( @@ -216,6 +229,12 @@ class HybridCypherRetriever(Retriever): Provides retrieval method using combination of vector search over embeddings and fulltext search, augmented by a Cypher query. This retriever builds on HybridRetriever. + + The retriever uses both the vector and fulltext indexes. It uses the user query to + search both indexes, retrieving nodes and their corresponding scores. + After normalizing the scores from each set of results, it merges them, ranks the + combined results by score, and returns the top matches. + If an embedder is provided, it needs to have the required Embedder type. Note: `node` is a variable from the base query that can be used in `retrieval_query` as seen in the example below. @@ -296,6 +315,8 @@ def get_search_results( query_vector: Optional[list[float]] = None, top_k: int = 5, query_params: Optional[dict[str, Any]] = None, + threshold_vector_index: float = 0.0, + threshold_fulltext_index: float = 0.0, ) -> RawSearchResult: """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text. Both query_vector and query_text can be provided. @@ -313,6 +334,8 @@ def get_search_results( query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None. 
top_k (int): The number of neighbors to return. Defaults to 5. query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None. + threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search. + threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search. Raises: SearchValidationError: If validation of the input arguments fail. @@ -335,6 +358,9 @@ def get_search_results( parameters["vector_index_name"] = self.vector_index_name parameters["fulltext_index_name"] = self.fulltext_index_name + parameters["threshold_vector_index"] = threshold_vector_index + parameters["threshold_fulltext_index"] = threshold_fulltext_index + if query_text and not query_vector: if not self.embedder: raise EmbeddingRequiredError( diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index c66bdf087..e3de5d724 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -203,6 +203,8 @@ def test_hybrid_search_text_happy_path( "query_text": query_text, "fulltext_index_name": fulltext_index_name, "query_vector": embed_query_vector, + "threshold_vector_index": 0.0, + "threshold_fulltext_index": 0.0, }, database_=None, routing_=neo4j.RoutingControl.READ, @@ -262,6 +264,8 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector( "query_text": query_text, "fulltext_index_name": fulltext_index_name, "query_vector": query_vector, + "threshold_vector_index": 0.0, + "threshold_fulltext_index": 0.0, }, database_=database, routing_=neo4j.RoutingControl.READ, @@ -345,6 +349,8 @@ def test_hybrid_retriever_return_properties( "query_text": query_text, "fulltext_index_name": fulltext_index_name, "query_vector": embed_query_vector, + "threshold_vector_index": 0.0, + "threshold_fulltext_index": 0.0, }, database_=None, routing_=neo4j.RoutingControl.READ, @@ -412,6 +418,8 @@ def 
def _squash_whitespace(text: str) -> str:
    """Collapse every run of whitespace to one space and trim both ends.

    The hybrid query is a multi-line string, so comparing it verbatim would
    tie the tests to its indentation rather than to the Cypher it contains.
    """
    return " ".join(text.split())


def _expected_hybrid_query(tail: str) -> str:
    """Return the whitespace-normalized hybrid query followed by ``tail``.

    The body is the query expected when no score threshold is requested, so
    neither branch carries a ``WHERE score > 0`` filter.
    """
    base = """CALL {
        CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) YIELD node, score
        WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score
        UNWIND nodes AS n
        WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
        THEN (n.score / vector_index_max_score) ELSE 0 END AS score
        RETURN node, score
        UNION
        CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) YIELD node, score
        WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score
        UNWIND nodes AS n
        WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
        THEN (n.score / ft_index_max_score) ELSE 0 END AS score
        RETURN node, score
    }
    WITH node, max(score) AS score
    ORDER BY score DESC LIMIT $top_k"""
    return _squash_whitespace(f"{base} {tail}")


def test_hybrid_search_basic() -> None:
    """Default hybrid search returns every node property."""
    expected = _expected_hybrid_query(
        "RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, "
        "elementId(node) AS elementId, elementId(node) AS id, score"
    )

    result, _ = get_search_query(SearchType.HYBRID)

    assert _squash_whitespace(result) == expected


def test_hybrid_search_with_retrieval_query() -> None:
    """A custom retrieval query replaces the default RETURN clause."""
    retrieval_query = "MATCH (n) RETURN n LIMIT 10"
    expected = _expected_hybrid_query(retrieval_query)

    result, _ = get_search_query(
        SearchType.HYBRID,
        retrieval_query=retrieval_query,
    )

    assert _squash_whitespace(result) == expected


def test_hybrid_search_with_properties() -> None:
    """Only the requested properties are projected from the node."""
    # This test must stay a top-level function: wrapping it in another
    # function of the same name makes pytest collect a test that asserts
    # nothing.
    properties = ["name", "age"]
    expected = _expected_hybrid_query(
        "RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, "
        "elementId(node) AS elementId, elementId(node) AS id, score"
    )

    result, _ = get_search_query(
        SearchType.HYBRID,
        return_properties=properties,
    )

    assert _squash_whitespace(result) == expected