fix: fix optimal chunking edge cases (#32)
lsorber authored Oct 15, 2024
1 parent a593ae9 commit 996b9ee
Showing 3 changed files with 82 additions and 13 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ RAGLite is a Python package for Retrieval-Augmented Generation (RAG) with Postgr

 - 🧠 Choose any LLM provider with [LiteLLM](https://github.com/BerriAI/litellm), including local [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) models
 - 💾 Choose either [PostgreSQL](https://github.com/postgres/postgres) or [SQLite](https://github.com/sqlite/sqlite) as a keyword & vector search database
-- 🥇 Choose any reranker with [rerankers](https://github.com/AnswerDotAI/rerankers), including multi-lingual [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) as the default
+- 🥇 Choose any reranker with [rerankers](https://github.com/AnswerDotAI/rerankers), including multilingual [FlashRank](https://github.com/PrithivirajDamodaran/FlashRank) as the default

 ##### Fast and permissive

src/raglite/_split_chunks.py (37 changes: 25 additions & 12 deletions)
@@ -9,26 +9,39 @@
 from raglite._typing import FloatMatrix


-def split_chunks(
+def split_chunks( # noqa: C901, PLR0915
     sentences: list[str],
     sentence_embeddings: FloatMatrix,
     sentence_window_size: int = 3,
     max_size: int = 1440,
 ) -> tuple[list[str], list[FloatMatrix]]:
     """Split sentences into optimal semantic chunks with corresponding sentence embeddings."""
+    # Validate the input.
+    sentence_length = np.asarray([len(sentence) for sentence in sentences])
+    if not np.all(sentence_length <= max_size):
+        error_message = "Sentence with length larger than chunk max_size detected."
+        raise ValueError(error_message)
+    if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0):
+        error_message = "Sentence embeddings with zero norm detected."
+        raise ValueError(error_message)
+    # Exit early if there is only one chunk to return.
+    if len(sentences) <= 1 or sum(sentence_length) <= max_size:
+        return ["".join(sentences)] if sentences else sentences, [sentence_embeddings]
     # Normalise the sentence embeddings to unit norm.
     X = sentence_embeddings.astype(np.float32) # noqa: N806
     X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
     # Select nonoutlying sentences and remove the discourse vector.
-    sentence_length = np.asarray([len(sentence) for sentence in sentences])
     q15, q85 = np.quantile(sentence_length, [0.15, 0.85])
-    outlying_sentences = (sentence_length <= q15) | (q85 <= sentence_length)
-    discourse = np.mean(X[~outlying_sentences, :], axis=0)
-    X = X - np.outer(X @ discourse, discourse) # noqa: N806
-    # Renormalise to unit norm.
-    X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
+    nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85)
+    discourse = np.mean(X[nonoutlying_sentences, :], axis=0)
+    discourse = discourse / np.linalg.norm(discourse)
+    if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
+        X = X - np.outer(X @ discourse, discourse) # noqa: N806
+        X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806
     # For each partition point in the list of sentences, compute the similarity of the windows
-    # before and after the partition point.
+    # before and after the partition point. Sentence embeddings are assumed to be of the sentence
+    # itself and at most the (sentence_window_size - 1) sentences that preceed it.
+    sentence_window_size = min(len(sentences) - 1, sentence_window_size)
     windows_before = X[:-sentence_window_size]
     windows_after = X[sentence_window_size:]
     partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
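The guard added around the discourse-vector removal above addresses the edge case where every sentence embedding coincides with the discourse vector: projecting it out would leave zero vectors, and the renormalisation would divide by zero. A minimal standalone sketch of that failure mode and the guard (plain NumPy with synthetic embeddings, not RAGLite's API):

```python
import numpy as np

# Synthetic edge case: all sentence embeddings are identical, so the discourse
# vector (their mean) equals every embedding exactly.
X = np.ones((4, 8), dtype=np.float32)
X = X / np.linalg.norm(X, axis=1, keepdims=True)

discourse = np.mean(X, axis=0)
discourse = discourse / np.linalg.norm(discourse)

# Only remove the discourse component if at least one embedding differs from it;
# otherwise X minus its projection would be all zeros and renormalisation would produce NaNs.
if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
    X = X - np.outer(X @ discourse, discourse)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)

print(np.isnan(X).any())  # False: the guard skipped the degenerate projection
```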
@@ -56,9 +69,7 @@ def split_chunks(
     for i in range(len(sentences) - 1):
         r = sentence_length_cumsum[i - 1] if i > 0 else 0
         idx = np.searchsorted(sentence_length_cumsum - r, max_size)
-        if idx == i:
-            error_message = "Sentence with length larger than chunk max_size detected."
-            raise ValueError(error_message)
+        assert idx > i
         if idx == len(sentence_length_cumsum):
             break
         cols = list(range(i, idx))
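For context, the searchsorted step above computes, for each candidate partition point i, how far a chunk starting at sentence i can extend before its total length reaches max_size; the assert replaces the old in-loop ValueError now that sentence lengths are validated up front. A small self-contained illustration with made-up sentence lengths (plain NumPy, not RAGLite's API):

```python
import numpy as np

sentence_length = np.asarray([200, 300, 900, 250, 800])  # hypothetical lengths
max_size = 1440
sentence_length_cumsum = np.cumsum(sentence_length)

for i in range(len(sentence_length) - 1):
    r = sentence_length_cumsum[i - 1] if i > 0 else 0
    # First index at which the running total of sentences i, i+1, ... reaches max_size,
    # i.e. sentences i..idx-1 are the longest run starting at i that still fits.
    idx = int(np.searchsorted(sentence_length_cumsum - r, max_size))
    print(i, idx, int(sentence_length[i:idx].sum()))

# Output columns are i, idx, and the total length of sentences i..idx-1:
# 0 3 1400 / 1 3 1200 / 2 4 1150 / 3 5 1050
```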
@@ -78,7 +89,9 @@ def split_chunks(
         bounds=(0, 1),
         integrality=[1] * A.shape[1],
     )
-    assert res.success, "Optimization of chunk partitions failed."
+    if not res.success:
+        error_message = "Optimization of chunk partitions failed."
+        raise ValueError(error_message)
     # Split the sentences and their window embeddings into optimal chunks.
     partition_indices = (np.where(res.x)[0] + 1).tolist()
     chunks = [
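The last hunk replaces an assert on the solver status with an explicit ValueError. A hedged toy sketch of the same pattern, solving a small binary linear program with SciPy's HiGHS backend (toy objective and constraints, not RAGLite's actual partitioning model):

```python
import numpy as np
from scipy.optimize import linprog

# Toy binary LP: minimise -x0 - 2*x1 (i.e. maximise x0 + 2*x1) subject to x0 + x1 <= 1,
# with both variables restricted to {0, 1} via the bounds plus integrality flags.
res = linprog(
    c=np.asarray([-1.0, -2.0]),
    A_ub=np.asarray([[1.0, 1.0]]),
    b_ub=np.asarray([1.0]),
    bounds=(0, 1),
    integrality=[1, 1],
)
# Raise instead of asserting, so the check also survives `python -O`.
if not res.success:
    error_message = "Optimization of chunk partitions failed."
    raise ValueError(error_message)
print(res.x)  # expected: [0., 1.]
```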
tests/test_split_chunks.py (56 changes: 56 additions & 0 deletions)
@@ -0,0 +1,56 @@
"""Test RAGLite's chunk splitting functionality."""

import numpy as np
import pytest

from raglite._split_chunks import split_chunks


@pytest.mark.parametrize(
"sentences",
[
pytest.param([], id="one_chunk:no_sentences"),
pytest.param(["Hello world"], id="one_chunk:one_sentence"),
pytest.param(["Hello world"] * 2, id="one_chunk:two_sentences"),
pytest.param(["Hello world"] * 3, id="one_chunk:three_sentences"),
pytest.param(["Hello world"] * 100, id="one_chunk:many_sentences"),
pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_sentences_a"),
pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_sentences_b"),
pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_sentences_a"),
pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_sentences_b"),
pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_sentences_c"),
pytest.param(["X" * 1000] * 100, id="n_chunks:many_sentences_a"),
pytest.param(["X" * 100] * 1000, id="n_chunks:many_sentences_b"),
],
)
def test_edge_cases(sentences: list[str]) -> None:
"""Test chunk splitting edge cases."""
sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
chunks, chunk_embeddings = split_chunks(
sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
)
assert isinstance(chunks, list)
assert isinstance(chunk_embeddings, list)
assert len(chunk_embeddings) == (len(chunks) if sentences else 1)
assert all(isinstance(chunk, str) for chunk in chunks)
assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings)
assert all(ce.dtype == sentence_embeddings.dtype for ce in chunk_embeddings)
assert sum(ce.shape[0] for ce in chunk_embeddings) == sentence_embeddings.shape[0]
assert all(ce.shape[1] == sentence_embeddings.shape[1] for ce in chunk_embeddings)


@pytest.mark.parametrize(
"sentences",
[
pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
],
)
def test_long_sentence(sentences: list[str]) -> None:
"""Test chunking on sentences that are too long."""
sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
with pytest.raises(
ValueError, match="Sentence with length larger than chunk max_size detected."
):
_ = split_chunks(sentences, sentence_embeddings, sentence_window_size=3, max_size=1440)
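A minimal usage sketch mirroring the new tests (assuming a development install so that raglite._split_chunks is importable): two short sentences whose combined length fits within max_size take the new early-exit path and come back as a single chunk.

```python
import numpy as np

from raglite._split_chunks import split_chunks

sentences = ["Hello world. ", "Goodbye world."]
sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
chunks, chunk_embeddings = split_chunks(
    sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
)
print(chunks)                  # ['Hello world. Goodbye world.'], a single joined chunk
print(len(chunk_embeddings))   # 1, with both sentence embeddings in that one chunk
```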
