Skip to content

Commit

Permalink
test: improve reranker test
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber committed Dec 4, 2024
1 parent fa15e3f commit 3b474c9
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions tests/test_rerank.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test RAGLite's reranking functionality."""

import random
from typing import TypeVar

import pytest
Expand All @@ -15,8 +16,8 @@

def kendall_tau(a: list[T], b: list[T]) -> float:
"""Measure the Kendall rank correlation coefficient between two lists."""
tau: float = kendalltau(range(len(a)), [a.index(el) for el in b])[0]
return tau
τ: float = kendalltau(range(len(a)), [a.index(el) for el in b])[0] # noqa: PLC2401
return τ


@pytest.fixture(
Expand Down Expand Up @@ -56,11 +57,14 @@ def test_reranker(
chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
assert all(isinstance(chunk, Chunk) for chunk in chunks)
assert all(chunk_id == chunk.id for chunk_id, chunk in zip(chunk_ids, chunks, strict=True))
# Rerank the chunks given an inverted chunk order.
reranked_chunks = rerank_chunks(query, chunks[::-1], config=raglite_test_config)
if reranker:
assert kendall_tau(chunks, reranked_chunks) >= kendall_tau(chunks[::-1], reranked_chunks)
# Test that we can also rerank given the chunk_ids only.
reranked_chunks = rerank_chunks(query, chunk_ids[::-1], config=raglite_test_config)
if reranker:
assert kendall_tau(chunks, reranked_chunks) >= kendall_tau(chunks[::-1], reranked_chunks)
# Randomly shuffle the chunks.
random.seed(42)
chunks_random = random.sample(chunks, len(chunks))
# Rerank the chunks starting from a pathological order and verify that it improves the ranking.
for arg in (chunks[::-1], chunk_ids[::-1]):
reranked_chunks = rerank_chunks(query, arg, config=raglite_test_config)
if reranker:
τ_search = kendall_tau(chunks, reranked_chunks) # noqa: PLC2401
τ_inverse = kendall_tau(chunks[::-1], reranked_chunks) # noqa: PLC2401
τ_random = kendall_tau(chunks_random, reranked_chunks) # noqa: PLC2401
assert τ_search >= τ_random >= τ_inverse

0 comments on commit 3b474c9

Please sign in to comment.