diff --git a/src/genai_utils/sentence_linking.py b/src/genai_utils/sentence_linking.py index a3bb145..81d24ad 100644 --- a/src/genai_utils/sentence_linking.py +++ b/src/genai_utils/sentence_linking.py @@ -4,9 +4,13 @@ from genai_utils.models import Span +DEFAULT_SIMILARITY_THRESHOLD = 80 + def find_quote_in_sentence( - sentence: str, quote: str, threshold: float = 80 + sentence: str, + quote: str, + threshold: float = DEFAULT_SIMILARITY_THRESHOLD, ) -> Span | None: """ Returns a span for the quote within the sentence if there's a match. @@ -17,6 +21,9 @@ def find_quote_in_sentence( The sentence to search inside quote: str The string to find inside the sentence + threshold: float + The threshold over which texts are considered to match. + It's an overlap percentage out of 100. Returns ------- @@ -36,7 +43,9 @@ def find_quote_in_sentence( def get_best_matching_sentence_for_quote( - original_quote: str, sentences: list[str] + original_quote: str, + sentences: list[str], + threshold: float = DEFAULT_SIMILARITY_THRESHOLD, ) -> tuple[int, Span, float] | None: """ Finds the best matching sentence for the given quote out of the provided sentences. @@ -48,6 +57,9 @@ def get_best_matching_sentence_for_quote( In most use cases, this would be the quote made by the Gen AI model. sentences: list[str] A list of sentences to search. + threshold: float + The threshold over which texts are considered to match. + It's an overlap percentage out of 100. Returns ------- @@ -61,7 +73,8 @@ def get_best_matching_sentence_for_quote( sentences_matching_quote = [ (sentence_idx, span, fuzz.ratio(original_quote, sentence)) for sentence_idx, sentence in enumerate(sentences) - if (span := find_quote_in_sentence(sentence, original_quote)) is not None + if (span := find_quote_in_sentence(sentence, original_quote, threshold)) + is not None ] # return None if there's no matching sentences @@ -75,7 +88,9 @@ def get_best_matching_sentence_for_quote( def link_quotes_and_sentences( - quotes: list[str], sentences: list[str] + quotes: list[str], + sentences: list[str], + threshold: float = DEFAULT_SIMILARITY_THRESHOLD, ) -> list[tuple[int, int, Span]]: """ Links pairs of matching quotes and sentences. @@ -88,6 +103,9 @@ def link_quotes_and_sentences( A list of quotes to find matches for. sentences: list[str] A list of sentences to search against each quote. + threshold: float + The threshold over which texts are considered to match. + It's an overlap percentage out of 100. Returns ------- @@ -101,6 +119,10 @@ def link_quotes_and_sentences( return [ (quote_idx, best_sentence[0], best_sentence[1]) for quote_idx, quote in enumerate(quotes) - if (best_sentence := get_best_matching_sentence_for_quote(quote, sentences)) + if ( + best_sentence := get_best_matching_sentence_for_quote( + quote, sentences, threshold + ) + ) is not None ] diff --git a/tests/genai_utils/test_sentence_linking.py b/tests/genai_utils/test_sentence_linking.py index 5e2f7a8..bda107c 100644 --- a/tests/genai_utils/test_sentence_linking.py +++ b/tests/genai_utils/test_sentence_linking.py @@ -319,3 +319,60 @@ def test_link_quotes_and_sentences( ) -> None: claimants_per_sentence = link_quotes_and_sentences(quotes, sentences) assert claimants_per_sentence == expected + + +@mark.parametrize( + "quotes,sentences,threshold,passes", + [ + param( + [ + "I will not seek reelection", + "I won't seek reelection", + "No reelection for me", + ], + [ + "I will not seek reelection", + ], + 70, + [True, True, True], + id="low threshold", + ), + param( + [ + "I will not seek reelection", + "I won't seek reelection", + "No reelection for me", + ], + [ + "I will not seek reelection", + ], + 80, + [True, True, False], + id="mid threshold", + ), + param( + [ + "I will not seek reelection", + "I won't seek reelection", + "No reelection for me", + ], + [ + "I will not seek reelection", + ], + 100, + [True, False, False], + id="high threshold", + ), + ], +) +def test_threshold_is_passed_along( + quotes: list[str], sentences: list[str], threshold: float, passes: list[bool] +): + linked = link_quotes_and_sentences(quotes, sentences, threshold) + + quotes_matched = [m[0] for m in linked] + for i, should_pass in enumerate(passes): + if should_pass: + assert i in quotes_matched + else: + assert i not in quotes_matched