"""A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""

- import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable
from dataclasses import dataclass
- from enum import Enum
from typing import Any, Literal, cast

import pandas as pd
import tiktoken

import graphrag.config.defaults as defs
- from graphrag.index.utils.tokens import num_tokens_from_string
+ from graphrag.index.operations.chunk_text.typing import TextChunk
+ from graphrag.logger.progress import ProgressTicker

EncodedText = list[int]
DecodeFn = Callable[[EncodedText], str]
@@ -123,10 +122,10 @@ def num_tokens(self, text: str) -> int:

    def split_text(self, text: str | list[str]) -> list[str]:
        """Split text method."""
-         if cast("bool", pd.isna(text)) or text == "":
-             return []
        if isinstance(text, list):
            text = " ".join(text)
+         elif cast("bool", pd.isna(text)) or text == "":
+             return []

        if not isinstance(text, str):
            msg = f"Attempting to split a non-string value, actual is {type(text)}"
            raise TypeError(msg)
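The reordering above is presumably needed because pd.isna() applied to a list returns an element-wise array rather than a scalar bool, so evaluating the emptiness check first would raise "The truth value of an array with more than one element is ambiguous" for list inputs. A rough illustration, not part of the diff:

    import pandas as pd

    pd.isna(None)          # True  -> scalar inputs still short-circuit to []
    pd.isna(["a", "b"])    # array([False, False]) -> ambiguous inside an `if` test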
@@ -138,108 +137,57 @@ def split_text(self, text: str | list[str]) -> list[str]:
            encode=lambda text: self.encode(text),
        )

-         return split_text_on_tokens(text=text, tokenizer=tokenizer)
-
-
- class TextListSplitterType(str, Enum):
-     """Enum for the type of the TextListSplitter."""
-
-     DELIMITED_STRING = "delimited_string"
-     JSON = "json"
-
-
- class TextListSplitter(TextSplitter):
-     """Text list splitter class definition."""
-
-     def __init__(
-         self,
-         chunk_size: int,
-         splitter_type: TextListSplitterType = TextListSplitterType.JSON,
-         input_delimiter: str | None = None,
-         output_delimiter: str | None = None,
-         model_name: str | None = None,
-         encoding_name: str | None = None,
-     ):
-         """Initialize the TextListSplitter with a chunk size."""
-         # Set the chunk overlap to 0 as we use full strings
-         super().__init__(chunk_size, chunk_overlap=0)
-         self._type = splitter_type
-         self._input_delimiter = input_delimiter
-         self._output_delimiter = output_delimiter or "\n"
-         self._length_function = lambda x: num_tokens_from_string(
-             x, model=model_name, encoding_name=encoding_name
-         )
-
-     def split_text(self, text: str | list[str]) -> Iterable[str]:
-         """Split a string list into a list of strings for a given chunk size."""
-         if not text:
-             return []
-
-         result: list[str] = []
-         current_chunk: list[str] = []
-
-         # Add the brackets
-         current_length: int = self._length_function("[]")
+         return split_single_text_on_tokens(text=text, tokenizer=tokenizer)

-         # Input should be a string list joined by a delimiter
-         string_list = self._load_text_list(text)

-         if len(string_list) == 1:
-             return string_list
-
-         for item in string_list:
-             # Count the length of the item and add comma
-             item_length = self._length_function(f"{item},")
+ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
+     """Split a single text and return chunks using the tokenizer."""
+     result = []
+     input_ids = tokenizer.encode(text)

-             if current_length + item_length > self._chunk_size:
-                 if current_chunk and len(current_chunk) > 0:
-                     # Add the current chunk to the result
-                     self._append_to_result(result, current_chunk)
+     start_idx = 0
+     cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+     chunk_ids = input_ids[start_idx:cur_idx]

-                     # Start a new chunk
-                     current_chunk = [item]
-                     # Add 2 for the brackets
-                     current_length = item_length
-             else:
-                 # Add the item to the current chunk
-                 current_chunk.append(item)
-                 # Add 1 for the comma
-                 current_length += item_length
+     while start_idx < len(input_ids):
+         chunk_text = tokenizer.decode(list(chunk_ids))
+         result.append(chunk_text)  # Append chunked text as string
+         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+         chunk_ids = input_ids[start_idx:cur_idx]

-         # Add the last chunk to the result
-         self._append_to_result(result, current_chunk)
+     return result

-         return result

-     def _load_text_list(self, text: str | list[str]):
-         """Load the text list based on the type."""
-         if isinstance(text, list):
-             string_list = text
-         elif self._type == TextListSplitterType.JSON:
-             string_list = json.loads(text)
-         else:
-             string_list = text.split(self._input_delimiter)
-         return string_list
+ # Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
+ # So we could have better control over the chunking process
+ def split_multiple_texts_on_tokens(
+     texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
+ ) -> list[TextChunk]:
+     """Split multiple texts and return chunks with metadata using the tokenizer."""
+     result = []
+     mapped_ids = []

-     def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
-         """Append the current chunk to the result."""
-         if new_chunk and len(new_chunk) > 0:
-             if self._type == TextListSplitterType.JSON:
-                 chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
-             else:
-                 chunk_list.append(self._output_delimiter.join(new_chunk))
+     for source_doc_idx, text in enumerate(texts):
+         encoded = tokenizer.encode(text)
+         if tick:
+             tick(1)  # Track progress if tick callback is provided
+         mapped_ids.append((source_doc_idx, encoded))

+     input_ids = [
+         (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
+     ]

- def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
-     """Split incoming text and return chunks using tokenizer."""
-     splits: list[str] = []
-     input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
+
    while start_idx < len(input_ids):
-         splits.append(tokenizer.decode(chunk_ids))
+         chunk_text = tokenizer.decode([id for _, id in chunk_ids])
+         doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
+         result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
-     return splits
+
+     return result
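For reference, a rough usage sketch of the two new helpers. The import path, the Tokenizer keyword arguments, and the list-append passed in place of a real ProgressTicker are assumptions based only on the attributes visible in this diff:

    import tiktoken

    from graphrag.index.text_splitting.text_splitting import (  # module path assumed
        Tokenizer,
        split_multiple_texts_on_tokens,
        split_single_text_on_tokens,
    )

    enc = tiktoken.get_encoding("cl100k_base")
    tokenizer = Tokenizer(
        chunk_overlap=64,
        tokens_per_chunk=512,
        decode=enc.decode,
        encode=lambda text: enc.encode(text),
    )

    # Single document -> list[str] of overlapping text chunks
    chunks = split_single_text_on_tokens(text="a long document ...", tokenizer=tokenizer)

    # Multiple documents -> list[TextChunk] carrying source doc indices and token counts
    ticks: list[int] = []
    doc_chunks = split_multiple_texts_on_tokens(
        ["doc one ...", "doc two ..."],
        tokenizer,
        tick=ticks.append,  # stand-in for a ProgressTicker; called as tick(1) per text
    )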