@@ -46,6 +46,8 @@ public class TokenTextSplitter extends TextSplitter {
4646
// Default for the keepSeparator setting; used by the no-arg constructor and as the Builder's initial value.
4747 private static final boolean KEEP_SEPARATOR = true ;
4848
// Default characters at which a chunk may be truncated: period, question mark,
// exclamation mark and newline. Used whenever no custom list is supplied.
49+ private static final List <Character > DEFAULT_PUNCTUATION_MARKS = List .of ('.' , '?' , '!' , '\n' );
50+
// Registry from which the tokenizer encoding below is obtained.
4951 private final EncodingRegistry registry = Encodings .newLazyEncodingRegistry ();
5052
// CL100K_BASE encoding; getEncodedTokens uses it to turn text into token ids.
5153 private final Encoding encoding = this .registry .getEncoding (EncodingType .CL100K_BASE );
@@ -64,21 +66,27 @@ public class TokenTextSplitter extends TextSplitter {
6466
// Separator-handling flag supplied through the constructor (its use is outside this view).
6567 private final boolean keepSeparator ;
6668
// Characters searched for by getLastPunctuationIndex when looking for a truncation point.
69+ private final List <Character > punctuationMarks ;
70+
/**
 * Creates a splitter that uses the class-level defaults for every setting,
 * including DEFAULT_PUNCTUATION_MARKS for the truncation characters.
 */
6771 public TokenTextSplitter () {
68- this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , KEEP_SEPARATOR );
72+ this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , KEEP_SEPARATOR ,
73+ DEFAULT_PUNCTUATION_MARKS );
6974 }
7075
/**
 * Creates a splitter with the default sizes, limits and punctuation marks,
 * overriding only the keepSeparator flag.
 *
 * @param keepSeparator value stored in the keepSeparator field
 */
7176 public TokenTextSplitter (boolean keepSeparator ) {
72- this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , keepSeparator );
77+ this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , keepSeparator ,
78+ DEFAULT_PUNCTUATION_MARKS );
7379 }
7480
/**
 * Creates a fully configured splitter. Every argument is stored as given.
 *
 * NOTE(review): the former five-argument signature (without punctuationMarks) is
 * replaced rather than overloaded, so existing callers of that constructor will
 * stop compiling. Consider keeping a five-argument overload that delegates here
 * with DEFAULT_PUNCTUATION_MARKS.
 *
 * @param chunkSize value stored in the chunkSize field
 * @param minChunkSizeChars doSplit truncates at a punctuation mark only when the
 * mark's index is greater than this value
 * @param minChunkLengthToEmbed value stored in the minChunkLengthToEmbed field
 * @param maxNumChunks value stored in the maxNumChunks field
 * @param keepSeparator value stored in the keepSeparator field
 * @param punctuationMarks characters at which a chunk may be truncated; must not be
 * empty (rejected by Assert.notEmpty; presumably an IllegalArgumentException if this
 * is Spring's Assert; confirm)
 */
7581 public TokenTextSplitter (int chunkSize , int minChunkSizeChars , int minChunkLengthToEmbed , int maxNumChunks ,
76- boolean keepSeparator ) {
82+ boolean keepSeparator , List < Character > punctuationMarks ) {
7783 this .chunkSize = chunkSize ;
7884 this .minChunkSizeChars = minChunkSizeChars ;
7985 this .minChunkLengthToEmbed = minChunkLengthToEmbed ;
8086 this .maxNumChunks = maxNumChunks ;
8187 this .keepSeparator = keepSeparator ;
// Validation runs after the other fields are assigned; a failure still aborts construction.
88+ Assert .notEmpty (punctuationMarks , "punctuationMarks must not be empty" );
// NOTE(review): the caller's list is stored by reference, so later mutation by the
// caller changes this splitter's behavior. Consider List.copyOf(punctuationMarks).
89+ this .punctuationMarks = punctuationMarks ;
8290 }
8391
8492 public static Builder builder () {
@@ -109,8 +117,7 @@ protected List<String> doSplit(String text, int chunkSize) {
109117 }
110118
111119 // Find the last period or punctuation mark in the chunk
112- int lastPunctuation = Math .max (chunkText .lastIndexOf ('.' ), Math .max (chunkText .lastIndexOf ('?' ),
113- Math .max (chunkText .lastIndexOf ('!' ), chunkText .lastIndexOf ('\n' ))));
120+ int lastPunctuation = getLastPunctuationIndex (chunkText );
114121
115122 if (lastPunctuation != -1 && lastPunctuation > this .minChunkSizeChars ) {
116123 // Truncate the chunk text at the punctuation mark
@@ -140,6 +147,16 @@ protected List<String> doSplit(String text, int chunkSize) {
140147 return chunks ;
141148 }
142149
/**
 * Finds the right-most position in the given text at which any of the configured
 * punctuation marks occurs.
 *
 * @param chunkText text to search
 * @return the greatest lastIndexOf result over all configured marks, or -1 when
 * none of them occurs in chunkText
 */
150+ protected int getLastPunctuationIndex (String chunkText ) {
151+ // find the max index of any punctuation mark
152+ int maxLastPunctuation = -1 ;
153+ for (Character punctuationMark : this .punctuationMarks ) {
// lastIndexOf yields -1 for an absent mark, so absent marks never raise the maximum.
154+ int lastPunctuation = chunkText .lastIndexOf (punctuationMark );
155+ maxLastPunctuation = Math .max (maxLastPunctuation , lastPunctuation );
156+ }
157+ return maxLastPunctuation ;
158+ }
159+
143160 private List <Integer > getEncodedTokens (String text ) {
144161 Assert .notNull (text , "Text must not be null" );
145162 return this .encoding .encode (text ).boxed ();
@@ -164,6 +181,8 @@ public static final class Builder {
164181
// Starts at KEEP_SEPARATOR until overridden on the builder.
165182 private boolean keepSeparator = KEEP_SEPARATOR ;
166183
// Starts at DEFAULT_PUNCTUATION_MARKS until replaced via withPunctuationMarks.
184+ private List <Character > punctuationMarks = DEFAULT_PUNCTUATION_MARKS ;
185+
// Private constructor: a Builder cannot be instantiated from outside the enclosing top-level class.
167186 private Builder () {
168187 }
169188
@@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
192211 return this ;
193212 }
194213
/**
 * Sets the characters at which a chunk may be truncated, replacing the default list.
 *
 * NOTE(review): the list is neither validated nor copied here. Validation only
 * happens in the TokenTextSplitter constructor that build() invokes, and the same
 * list instance is handed on to the splitter.
 *
 * @param punctuationMarks characters to pass to the splitter on build()
 * @return this builder, for chaining
 */
214+ public Builder withPunctuationMarks (List <Character > punctuationMarks ) {
215+ this .punctuationMarks = punctuationMarks ;
216+ return this ;
217+ }
218+
/**
 * Creates a TokenTextSplitter from the values currently held by this builder,
 * including the punctuation marks.
 *
 * @return a new splitter instance
 */
195219 public TokenTextSplitter build () {
196220 return new TokenTextSplitter (this .chunkSize , this .minChunkSizeChars , this .minChunkLengthToEmbed ,
197- this .maxNumChunks , this .keepSeparator );
221+ this .maxNumChunks , this .keepSeparator , this . punctuationMarks );
198222 }
199223
200224 }
0 commit comments