Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/genai_utils/sentence_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@

from genai_utils.models import Span

DEFAULT_SIMILARITY_THRESHOLD = 80


def find_quote_in_sentence(
sentence: str, quote: str, threshold: float = 80
sentence: str,
quote: str,
threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
) -> Span | None:
"""
Returns a span for the quote within the sentence if there's a match.
Expand All @@ -17,6 +21,9 @@ def find_quote_in_sentence(
The sentence to search inside
quote: str
The string to find inside the sentence
threshold: float
The threshold over which texts are considered to match.
It's an overlap percentage out of 100.

Returns
-------
Expand All @@ -36,7 +43,9 @@ def find_quote_in_sentence(


def get_best_matching_sentence_for_quote(
original_quote: str, sentences: list[str]
original_quote: str,
sentences: list[str],
threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
) -> tuple[int, Span, float] | None:
"""
Finds the best matching sentence for the given quote out of the provided sentences.
Expand All @@ -48,6 +57,9 @@ def get_best_matching_sentence_for_quote(
In most use cases, this would be the quote made by the Gen AI model.
sentences: list[str]
A list of sentences to search.
threshold: float
The threshold over which texts are considered to match.
It's an overlap percentage out of 100.

Returns
-------
Expand All @@ -61,7 +73,8 @@ def get_best_matching_sentence_for_quote(
sentences_matching_quote = [
(sentence_idx, span, fuzz.ratio(original_quote, sentence))
for sentence_idx, sentence in enumerate(sentences)
if (span := find_quote_in_sentence(sentence, original_quote)) is not None
if (span := find_quote_in_sentence(sentence, original_quote, threshold))
is not None
]

# return None if there are no matching sentences
Expand All @@ -75,7 +88,9 @@ def get_best_matching_sentence_for_quote(


def link_quotes_and_sentences(
quotes: list[str], sentences: list[str]
quotes: list[str],
sentences: list[str],
threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
) -> list[tuple[int, int, Span]]:
"""
Links pairs of matching quotes and sentences.
Expand All @@ -88,6 +103,9 @@ def link_quotes_and_sentences(
A list of quotes to find matches for.
sentences: list[str]
A list of sentences to search against each quote.
threshold: float
The threshold over which texts are considered to match.
It's an overlap percentage out of 100.

Returns
-------
Expand All @@ -101,6 +119,10 @@ def link_quotes_and_sentences(
return [
(quote_idx, best_sentence[0], best_sentence[1])
for quote_idx, quote in enumerate(quotes)
if (best_sentence := get_best_matching_sentence_for_quote(quote, sentences))
if (
best_sentence := get_best_matching_sentence_for_quote(
quote, sentences, threshold
)
)
is not None
]
57 changes: 57 additions & 0 deletions tests/genai_utils/test_sentence_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,60 @@ def test_link_quotes_and_sentences(
) -> None:
claimants_per_sentence = link_quotes_and_sentences(quotes, sentences)
assert claimants_per_sentence == expected


@mark.parametrize(
    "quotes,sentences,threshold,passes",
    [
        param(
            [
                "I will not seek reelection",
                "I won't seek reelection",
                "No reelection for me",
            ],
            [
                "I will not seek reelection",
            ],
            70,
            [True, True, True],
            id="low threshold",
        ),
        param(
            [
                "I will not seek reelection",
                "I won't seek reelection",
                "No reelection for me",
            ],
            [
                "I will not seek reelection",
            ],
            80,
            [True, True, False],
            id="mid threshold",
        ),
        param(
            [
                "I will not seek reelection",
                "I won't seek reelection",
                "No reelection for me",
            ],
            [
                "I will not seek reelection",
            ],
            100,
            [True, False, False],
            id="high threshold",
        ),
    ],
)
def test_threshold_is_passed_along(
    quotes: list[str], sentences: list[str], threshold: float, passes: list[bool]
) -> None:
    """
    Checks that the threshold given to ``link_quotes_and_sentences`` decides
    which quotes end up linked to a sentence.

    Parameters
    ----------
    quotes: list[str]
        The quotes to link against the sentences.
    sentences: list[str]
        The sentences to search.
    threshold: float
        The threshold handed to ``link_quotes_and_sentences``.
    passes: list[bool]
        For each quote (by position), whether it is expected to be linked.
    """
    linked = link_quotes_and_sentences(quotes, sentences, threshold)

    # The first element of every returned link is the index of the quote.
    matched_quote_indices = {link[0] for link in linked}

    for quote_idx, should_pass in enumerate(passes):
        was_linked = quote_idx in matched_quote_indices
        assert was_linked == should_pass, (
            f"quote {quote_idx}: linked={was_linked}, "
            f"expected={should_pass} at threshold {threshold}"
        )