Skip to content

Commit 2f2cfa7

Browse files
authored
Test and unify text splitter functionality (#1547)
* add text_splitting unit test * change folder test text splitting * fix chunk fn * test new function * run formatter * run spell check * run semver * remove tiktoken mocked from tests * change progress ticker * fix ruff check
1 parent 0e7d22b commit 2f2cfa7

File tree

5 files changed

+213
-133
lines changed

5 files changed

+213
-133
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "unit tests for text_splitting"
4+
}

Diff for: graphrag/index/operations/chunk_text/strategies.py

+5-40
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111
from graphrag.config.models.chunking_config import ChunkingConfig
1212
from graphrag.index.operations.chunk_text.typing import TextChunk
13-
from graphrag.index.text_splitting.text_splitting import Tokenizer
13+
from graphrag.index.text_splitting.text_splitting import (
14+
Tokenizer,
15+
split_multiple_texts_on_tokens,
16+
)
1417
from graphrag.logger.progress import ProgressTicker
1518

1619

@@ -31,7 +34,7 @@ def encode(text: str) -> list[int]:
3134
def decode(tokens: list[int]) -> str:
3235
return enc.decode(tokens)
3336

34-
return _split_text_on_tokens(
37+
return split_multiple_texts_on_tokens(
3538
input,
3639
Tokenizer(
3740
chunk_overlap=chunk_overlap,
@@ -43,44 +46,6 @@ def decode(tokens: list[int]) -> str:
4346
)
4447

4548

46-
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
47-
# So we could have better control over the chunking process
48-
def _split_text_on_tokens(
49-
texts: list[str], enc: Tokenizer, tick: ProgressTicker
50-
) -> list[TextChunk]:
51-
"""Split incoming text and return chunks."""
52-
result = []
53-
mapped_ids = []
54-
55-
for source_doc_idx, text in enumerate(texts):
56-
encoded = enc.encode(text)
57-
tick(1)
58-
mapped_ids.append((source_doc_idx, encoded))
59-
60-
input_ids: list[tuple[int, int]] = [
61-
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
62-
]
63-
64-
start_idx = 0
65-
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
66-
chunk_ids = input_ids[start_idx:cur_idx]
67-
while start_idx < len(input_ids):
68-
chunk_text = enc.decode([id for _, id in chunk_ids])
69-
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
70-
result.append(
71-
TextChunk(
72-
text_chunk=chunk_text,
73-
source_doc_indices=doc_indices,
74-
n_tokens=len(chunk_ids),
75-
)
76-
)
77-
start_idx += enc.tokens_per_chunk - enc.chunk_overlap
78-
cur_idx = min(start_idx + enc.tokens_per_chunk, len(input_ids))
79-
chunk_ids = input_ids[start_idx:cur_idx]
80-
81-
return result
82-
83-
8449
def run_sentences(
8550
input: list[str], _config: ChunkingConfig, tick: ProgressTicker
8651
) -> Iterable[TextChunk]:

Diff for: graphrag/index/text_splitting/text_splitting.py

+41-93
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33

44
"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""
55

6-
import json
76
import logging
87
from abc import ABC, abstractmethod
98
from collections.abc import Callable, Collection, Iterable
109
from dataclasses import dataclass
11-
from enum import Enum
1210
from typing import Any, Literal, cast
1311

1412
import pandas as pd
1513
import tiktoken
1614

1715
import graphrag.config.defaults as defs
18-
from graphrag.index.utils.tokens import num_tokens_from_string
16+
from graphrag.index.operations.chunk_text.typing import TextChunk
17+
from graphrag.logger.progress import ProgressTicker
1918

2019
EncodedText = list[int]
2120
DecodeFn = Callable[[EncodedText], str]
@@ -123,10 +122,10 @@ def num_tokens(self, text: str) -> int:
123122

124123
def split_text(self, text: str | list[str]) -> list[str]:
125124
"""Split text method."""
126-
if cast("bool", pd.isna(text)) or text == "":
127-
return []
128125
if isinstance(text, list):
129126
text = " ".join(text)
127+
elif cast("bool", pd.isna(text)) or text == "":
128+
return []
130129
if not isinstance(text, str):
131130
msg = f"Attempting to split a non-string value, actual is {type(text)}"
132131
raise TypeError(msg)
@@ -138,108 +137,57 @@ def split_text(self, text: str | list[str]) -> list[str]:
138137
encode=lambda text: self.encode(text),
139138
)
140139

141-
return split_text_on_tokens(text=text, tokenizer=tokenizer)
142-
143-
144-
class TextListSplitterType(str, Enum):
145-
"""Enum for the type of the TextListSplitter."""
146-
147-
DELIMITED_STRING = "delimited_string"
148-
JSON = "json"
149-
150-
151-
class TextListSplitter(TextSplitter):
152-
"""Text list splitter class definition."""
153-
154-
def __init__(
155-
self,
156-
chunk_size: int,
157-
splitter_type: TextListSplitterType = TextListSplitterType.JSON,
158-
input_delimiter: str | None = None,
159-
output_delimiter: str | None = None,
160-
model_name: str | None = None,
161-
encoding_name: str | None = None,
162-
):
163-
"""Initialize the TextListSplitter with a chunk size."""
164-
# Set the chunk overlap to 0 as we use full strings
165-
super().__init__(chunk_size, chunk_overlap=0)
166-
self._type = splitter_type
167-
self._input_delimiter = input_delimiter
168-
self._output_delimiter = output_delimiter or "\n"
169-
self._length_function = lambda x: num_tokens_from_string(
170-
x, model=model_name, encoding_name=encoding_name
171-
)
172-
173-
def split_text(self, text: str | list[str]) -> Iterable[str]:
174-
"""Split a string list into a list of strings for a given chunk size."""
175-
if not text:
176-
return []
177-
178-
result: list[str] = []
179-
current_chunk: list[str] = []
180-
181-
# Add the brackets
182-
current_length: int = self._length_function("[]")
140+
return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
183141

184-
# Input should be a string list joined by a delimiter
185-
string_list = self._load_text_list(text)
186142

187-
if len(string_list) == 1:
188-
return string_list
189-
190-
for item in string_list:
191-
# Count the length of the item and add comma
192-
item_length = self._length_function(f"{item},")
143+
def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
144+
"""Split a single text and return chunks using the tokenizer."""
145+
result = []
146+
input_ids = tokenizer.encode(text)
193147

194-
if current_length + item_length > self._chunk_size:
195-
if current_chunk and len(current_chunk) > 0:
196-
# Add the current chunk to the result
197-
self._append_to_result(result, current_chunk)
148+
start_idx = 0
149+
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
150+
chunk_ids = input_ids[start_idx:cur_idx]
198151

199-
# Start a new chunk
200-
current_chunk = [item]
201-
# Add 2 for the brackets
202-
current_length = item_length
203-
else:
204-
# Add the item to the current chunk
205-
current_chunk.append(item)
206-
# Add 1 for the comma
207-
current_length += item_length
152+
while start_idx < len(input_ids):
153+
chunk_text = tokenizer.decode(list(chunk_ids))
154+
result.append(chunk_text) # Append chunked text as string
155+
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
156+
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
157+
chunk_ids = input_ids[start_idx:cur_idx]
208158

209-
# Add the last chunk to the result
210-
self._append_to_result(result, current_chunk)
159+
return result
211160

212-
return result
213161

214-
def _load_text_list(self, text: str | list[str]):
215-
"""Load the text list based on the type."""
216-
if isinstance(text, list):
217-
string_list = text
218-
elif self._type == TextListSplitterType.JSON:
219-
string_list = json.loads(text)
220-
else:
221-
string_list = text.split(self._input_delimiter)
222-
return string_list
162+
# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
163+
# So we could have better control over the chunking process
164+
def split_multiple_texts_on_tokens(
165+
texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
166+
) -> list[TextChunk]:
167+
"""Split multiple texts and return chunks with metadata using the tokenizer."""
168+
result = []
169+
mapped_ids = []
223170

224-
def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
225-
"""Append the current chunk to the result."""
226-
if new_chunk and len(new_chunk) > 0:
227-
if self._type == TextListSplitterType.JSON:
228-
chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
229-
else:
230-
chunk_list.append(self._output_delimiter.join(new_chunk))
171+
for source_doc_idx, text in enumerate(texts):
172+
encoded = tokenizer.encode(text)
173+
if tick:
174+
tick(1) # Track progress if tick callback is provided
175+
mapped_ids.append((source_doc_idx, encoded))
231176

177+
input_ids = [
178+
(source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
179+
]
232180

233-
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
234-
"""Split incoming text and return chunks using tokenizer."""
235-
splits: list[str] = []
236-
input_ids = tokenizer.encode(text)
237181
start_idx = 0
238182
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
239183
chunk_ids = input_ids[start_idx:cur_idx]
184+
240185
while start_idx < len(input_ids):
241-
splits.append(tokenizer.decode(chunk_ids))
186+
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
187+
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
188+
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
242189
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
243190
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
244191
chunk_ids = input_ids[start_idx:cur_idx]
245-
return splits
192+
193+
return result

Diff for: tests/unit/indexing/text_splitting/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License

0 commit comments

Comments (0)