Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow thresholding on vector and fulltext indexes for Hybrid retrievers #239

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
- Support for conversations with message history, including a new `message_history` parameter for LLM interactions.
- Ability to include system instructions in LLM invoke method.
- Summarization of chat history to enhance query embedding and context handling in GraphRAG.
- Support for thresholding on vector and fulltext indexes in Hybrid retrievers, enabling users to set importance levels for search results.

### Changed
- Updated LLM implementations to handle message history consistently across providers.
Expand Down
9 changes: 8 additions & 1 deletion examples/retrieve/hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@
# Perform the similarity search for a text query
# (retrieve the top 5 most similar nodes)
query_text = "Find me a movie about aliens"
print(retriever.search(query_text=query_text, top_k=5))
results = retriever.search(
query_text=query_text,
top_k=5,
threshold_vector_index=0.1,
threshold_fulltext_index=0.8,
)

print(results.items[0].metadata)

# note: it is also possible to query from a query_vector directly:
# query_vector: list[float] = [...]
Expand Down
70 changes: 42 additions & 28 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,33 +116,50 @@
)


def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
def _get_hybrid_query(
neo4j_version_is_5_23_or_above: bool,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> str:
vector_where_clause = "WHERE score > 0" if threshold_vector_index > 0 else ""
fulltext_where_clause = "WHERE score > 0" if threshold_fulltext_index > 0 else ""

if neo4j_version_is_5_23_or_above:
return (
f"CALL () {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)
return f"""CALL () {{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change this to a multi-line string here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helps me modify the unit tests more easily without worrying about correct indentation and spaces

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With multi-line strings, you do have to worry about correct indentation though. With the

(
"Hello "
"world"
)

approach you only have to worry about there being a space at the end of every line.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I still generally prefer multi-line strings as I find them more readable and tend to avoid mistakes after changes. Do you think we should revert this back?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m inclined to prefer the older approach, but I’m also open to this option. @stellasia what are your thoughts?

{VECTOR_INDEX_QUERY}
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
THEN (n.score / vector_index_max_score) ELSE 0 END AS score {vector_where_clause}
RETURN node, score
UNION
{FULL_TEXT_SEARCH_QUERY}
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
THEN (n.score / ft_index_max_score) ELSE 0 END AS score {fulltext_where_clause}
RETURN node, score
}}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k"""
else:
return (
f"CALL {{ {VECTOR_INDEX_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
f"UNION "
f"{FULL_TEXT_SEARCH_QUERY} "
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
f"UNWIND nodes AS n "
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
)
return f"""CALL {{
{VECTOR_INDEX_QUERY}
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
THEN (n.score / vector_index_max_score) ELSE 0 END AS score{vector_where_clause}
RETURN node, score
UNION
{FULL_TEXT_SEARCH_QUERY}
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
THEN (n.score / ft_index_max_score) ELSE 0 END AS score{fulltext_where_clause}
RETURN node, score
}}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k"""


def _get_filtered_vector_query(
Expand Down Expand Up @@ -186,7 +203,6 @@ def get_search_query(
neo4j_version_is_5_23_or_above: bool = False,
) -> tuple[str, dict[str, Any]]:
"""Build the search query, including pre-filtering if needed, and return clause.

Args
search_type: Search type we want to search for:
return_properties (list[str]): list of property names to return.
Expand All @@ -197,10 +213,8 @@ def get_search_query(
embedding_node_property (str): the name of the property holding the embeddings
embedding_dimension (int): the dimension of the embeddings
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search

Returns:
tuple[str, dict[str, Any]]: query and parameters

"""
warnings.warn(
"The default returned 'id' field in the search results will be removed. Please switch to using 'elementId' instead.",
Expand Down
26 changes: 26 additions & 0 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class HybridRetriever(Retriever):
"""
Provides retrieval method using combination of vector search over embeddings and
fulltext search.

The retriever uses both the vector and fulltext indexes. It uses the user query to
search both indexes, retrieving nodes and their corresponding scores.
After normalizing the scores from each set of results, it merges them, ranks the
combined results by score, and returns the top matches.

If an embedder is provided, it needs to have the required Embedder type.

Example:
Expand Down Expand Up @@ -141,6 +147,8 @@ def get_search_results(
query_text: str,
query_vector: Optional[list[float]] = None,
top_k: int = 5,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -159,6 +167,8 @@ def get_search_results(
query_text (str): The text to get the closest neighbors of.
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int, optional): The number of neighbors to return. Defaults to 5.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.

Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -180,6 +190,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down Expand Up @@ -216,6 +229,12 @@ class HybridCypherRetriever(Retriever):
Provides retrieval method using combination of vector search over embeddings and
fulltext search, augmented by a Cypher query.
This retriever builds on HybridRetriever.

The retriever uses both the vector and fulltext indexes. It uses the user query to
search both indexes, retrieving nodes and their corresponding scores.
After normalizing the scores from each set of results, it merges them, ranks the
combined results by score, and returns the top matches.

If an embedder is provided, it needs to have the required Embedder type.

Note: `node` is a variable from the base query that can be used in `retrieval_query` as seen in the example below.
Expand Down Expand Up @@ -296,6 +315,8 @@ def get_search_results(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
query_params: Optional[dict[str, Any]] = None,
threshold_vector_index: float = 0.0,
threshold_fulltext_index: float = 0.0,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -313,6 +334,8 @@ def get_search_results(
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
top_k (int): The number of neighbors to return. Defaults to 5.
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
threshold_vector_index (float, optional): The minimum normalized score from the vector index to include in the top k search.
threshold_fulltext_index (float, optional): The minimum normalized score from the fulltext index to include in the top k search.

Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -335,6 +358,9 @@ def get_search_results(
parameters["vector_index_name"] = self.vector_index_name
parameters["fulltext_index_name"] = self.fulltext_index_name

parameters["threshold_vector_index"] = threshold_vector_index
parameters["threshold_fulltext_index"] = threshold_fulltext_index

if query_text and not query_vector:
if not self.embedder:
raise EmbeddingRequiredError(
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/retrievers/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def test_hybrid_search_text_happy_path(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -262,6 +264,8 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=database,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -345,6 +349,8 @@ def test_hybrid_retriever_return_properties(
"query_text": query_text,
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down Expand Up @@ -412,6 +418,8 @@ def test_hybrid_cypher_retrieval_query_with_params(
"fulltext_index_name": fulltext_index_name,
"query_vector": embed_query_vector,
"param": "dummy-param",
"threshold_vector_index": 0.0,
"threshold_fulltext_index": 0.0,
},
database_=None,
routing_=neo4j.RoutingControl.READ,
Expand Down
118 changes: 67 additions & 51 deletions tests/unit/test_neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@ def test_vector_search_basic() -> None:


def test_hybrid_search_basic() -> None:
expected = (
"CALL { "
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
expected = """CALL {
CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) YIELD node, score
WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
THEN (n.score / vector_index_max_score) ELSE 0 END AS score
RETURN node, score
UNION
CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) YIELD node, score
WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
THEN (n.score / ft_index_max_score) ELSE 0 END AS score
RETURN node, score
}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score""".strip()

result, _ = get_search_query(SearchType.HYBRID)
assert result.strip() == expected.strip()

Expand Down Expand Up @@ -123,46 +125,60 @@ def test_vector_search_with_params_from_filters(_mock: Any) -> None:

def test_hybrid_search_with_retrieval_query() -> None:
retrieval_query = "MATCH (n) RETURN n LIMIT 10"
expected = (
"CALL { "
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
+ retrieval_query
expected = f"""CALL {{
CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) YIELD node, score
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
THEN (n.score / vector_index_max_score) ELSE 0 END AS score
RETURN node, score
UNION
CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {{limit: $top_k}}) YIELD node, score
WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
THEN (n.score / ft_index_max_score) ELSE 0 END AS score
RETURN node, score
}}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k {retrieval_query}""".strip()

result, _ = get_search_query(
SearchType.HYBRID,
retrieval_query=retrieval_query,
)
result, _ = get_search_query(SearchType.HYBRID, retrieval_query=retrieval_query)
assert result.strip() == expected.strip()


def test_hybrid_search_with_properties() -> None:
properties = ["name", "age"]
expected = (
"CALL { "
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / vector_index_max_score) AS score UNION "
"CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) "
"YIELD node, score "
"WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score "
"UNWIND nodes AS n "
"RETURN n.node AS node, (n.score / ft_index_max_score) AS score "
"} "
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
"RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score"
)
result, _ = get_search_query(SearchType.HYBRID, return_properties=properties)
assert result.strip() == expected.strip()
def test_hybrid_search_with_properties() -> None:
properties = ["name", "age"]
expected = """CALL {
CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) YIELD node, score
WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / vector_index_max_score) >= $threshold_vector_index
THEN (n.score / vector_index_max_score) ELSE 0 END AS score
WHERE score > 0
RETURN node, score
UNION
CALL db.index.fulltext.queryNodes($fulltext_index_name, $query_text, {limit: $top_k}) YIELD node, score
WITH collect({node:node, score:score}) AS nodes, max(score) AS ft_index_max_score
UNWIND nodes AS n
WITH n.node AS node, CASE WHEN (n.score / ft_index_max_score) >= $threshold_fulltext_index
THEN (n.score / ft_index_max_score) ELSE 0 END AS score
WHERE score > 0
RETURN node, score
}
WITH node, max(score) AS score
ORDER BY score DESC LIMIT $top_k
RETURN node {.name, .age} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, elementId(node) AS id, score""".strip()

result, _ = get_search_query(
SearchType.HYBRID,
return_properties=properties,
)
assert result.strip() == expected.strip()


def test_get_query_tail_with_retrieval_query() -> None:
Expand Down
Loading