diff --git a/qdrant_client/hybrid/fusion.py b/qdrant_client/hybrid/fusion.py index 025df256..015f4a73 100644 --- a/qdrant_client/hybrid/fusion.py +++ b/qdrant_client/hybrid/fusion.py @@ -33,6 +33,9 @@ def distribution_based_score_fusion( responses: list[list[models.ScoredPoint]], limit: int ) -> list[models.ScoredPoint]: def normalize(response: list[models.ScoredPoint]) -> list[models.ScoredPoint]: + if len(response) <= 1: + return response + total = sum([point.score for point in response]) mean = total / len(response) variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1) diff --git a/qdrant_client/hybrid/test_reranking.py b/qdrant_client/hybrid/test_reranking.py index a7d89743..fa9aa05c 100644 --- a/qdrant_client/hybrid/test_reranking.py +++ b/qdrant_client/hybrid/test_reranking.py @@ -50,6 +50,9 @@ def test_distribution_based_score_fusion() -> None: assert fused[1].id == 0 assert fused[2].id == 4 + fused = distribution_based_score_fusion([[responses[0][0]]], limit=3) + assert fused[0].id == 1 + def test_reciprocal_rank_fusion_empty_responses() -> None: responses: list[list[models.ScoredPoint]] = [[]]