Skip to content

Commit 632df5f

Browse files
author
oneby-wang
committed
feat: Support custom punctuation marks in TokenTextSplitter
Signed-off-by: oneby-wang <[email protected]>
1 parent 920e6f4 commit 632df5f

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ public class TokenTextSplitter extends TextSplitter {
4646

4747
private static final boolean KEEP_SEPARATOR = true;
4848

49+
private static final List<Character> DEFAULT_PUNCTUATION_MARKS = List.of('.', '?', '!', '\n');
50+
4951
private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
5052

5153
private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
@@ -64,21 +66,27 @@ public class TokenTextSplitter extends TextSplitter {
6466

6567
private final boolean keepSeparator;
6668

69+
private final List<Character> punctuationMarks;
70+
6771
/**
 * Creates a splitter configured entirely from the class-level defaults. After this
 * change the delegation also passes {@code DEFAULT_PUNCTUATION_MARKS} as the set of
 * punctuation marks.
 */
public TokenTextSplitter() {
68-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
72+
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR,
73+
DEFAULT_PUNCTUATION_MARKS);
6974
}
7075

7176
/**
 * Creates a splitter that uses the class-level defaults for every setting except
 * {@code keepSeparator}. After this change the delegation also passes
 * {@code DEFAULT_PUNCTUATION_MARKS} as the set of punctuation marks.
 * @param keepSeparator value forwarded unchanged to the full constructor
 */
public TokenTextSplitter(boolean keepSeparator) {
72-
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
77+
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator,
78+
DEFAULT_PUNCTUATION_MARKS);
7379
}
7480

7581
/**
 * Creates a splitter with explicit sizing options and the default punctuation marks.
 * <p>
 * Kept so that callers compiled against the five-argument constructor continue to
 * work after the punctuation-mark parameter was introduced.
 * @param chunkSize value stored as the splitter's chunk size
 * @param minChunkSizeChars value stored as the minimum chunk size in characters; a
 * punctuation index must exceed it before a chunk is truncated there
 * @param minChunkLengthToEmbed value stored as the minimum chunk length to embed
 * @param maxNumChunks value stored as the maximum number of chunks
 * @param keepSeparator value stored as the keep-separator flag
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
		boolean keepSeparator) {
	this(chunkSize, minChunkSizeChars, minChunkLengthToEmbed, maxNumChunks, keepSeparator,
			DEFAULT_PUNCTUATION_MARKS);
}

/**
 * Creates a splitter with explicit sizing options and a custom set of punctuation
 * marks used to find the truncation point inside a chunk.
 * @param chunkSize value stored as the splitter's chunk size
 * @param minChunkSizeChars value stored as the minimum chunk size in characters; a
 * punctuation index must exceed it before a chunk is truncated there
 * @param minChunkLengthToEmbed value stored as the minimum chunk length to embed
 * @param maxNumChunks value stored as the maximum number of chunks
 * @param keepSeparator value stored as the keep-separator flag
 * @param punctuationMarks characters treated as punctuation marks; must not be empty
 * and must not contain {@code null} elements
 * @throws IllegalArgumentException if {@code punctuationMarks} fails the non-empty
 * check
 * @throws NullPointerException if {@code punctuationMarks} contains a {@code null}
 * element
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
		boolean keepSeparator, List<Character> punctuationMarks) {
	Assert.notEmpty(punctuationMarks, "punctuationMarks must not be empty");
	this.chunkSize = chunkSize;
	this.minChunkSizeChars = minChunkSizeChars;
	this.minChunkLengthToEmbed = minChunkLengthToEmbed;
	this.maxNumChunks = maxNumChunks;
	this.keepSeparator = keepSeparator;
	// Snapshot the list: keeping the caller's reference would let a later clear() or
	// add() on their side silently change (or empty) this splitter's punctuation set
	// after the non-empty check above has already passed.
	this.punctuationMarks = List.copyOf(punctuationMarks);
}
8391

8492
public static Builder builder() {
@@ -109,8 +117,7 @@ protected List<String> doSplit(String text, int chunkSize) {
109117
}
110118

111119
// Find the last period or punctuation mark in the chunk
112-
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
113-
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
120+
int lastPunctuation = getLastPunctuationIndex(chunkText);
114121

115122
if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
116123
// Truncate the chunk text at the punctuation mark
@@ -140,6 +147,16 @@ protected List<String> doSplit(String text, int chunkSize) {
140147
return chunks;
141148
}
142149

150+
protected int getLastPunctuationIndex(String chunkText) {
151+
// find the max index of any punctuation mark
152+
int maxLastPunctuation = -1;
153+
for (Character punctuationMark : this.punctuationMarks) {
154+
int lastPunctuation = chunkText.lastIndexOf(punctuationMark);
155+
maxLastPunctuation = Math.max(maxLastPunctuation, lastPunctuation);
156+
}
157+
return maxLastPunctuation;
158+
}
159+
143160
private List<Integer> getEncodedTokens(String text) {
144161
Assert.notNull(text, "Text must not be null");
145162
return this.encoding.encode(text).boxed();
@@ -164,6 +181,8 @@ public static final class Builder {
164181

165182
private boolean keepSeparator = KEEP_SEPARATOR;
166183

184+
private List<Character> punctuationMarks = DEFAULT_PUNCTUATION_MARKS;
185+
167186
private Builder() {
168187
}
169188

@@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
192211
return this;
193212
}
194213

214+
/**
 * Sets the punctuation marks that the built splitter will use. The list is stored
 * as given and is not checked here; it is handed to the {@code TokenTextSplitter}
 * constructor by {@code build()}, which asserts that it is not empty.
 * @param punctuationMarks the characters to treat as punctuation marks
 * @return this builder
 */
public Builder withPunctuationMarks(List<Character> punctuationMarks) {
215+
this.punctuationMarks = punctuationMarks;
216+
return this;
217+
}
218+
195219
/**
 * Creates a {@code TokenTextSplitter} from the values currently held by this
 * builder. After this change the builder's punctuation marks are passed to the
 * constructor as well.
 * @return a new splitter instance
 */
public TokenTextSplitter build() {
196220
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
197-
this.maxNumChunks, this.keepSeparator);
221+
this.maxNumChunks, this.keepSeparator, this.punctuationMarks);
198222
}
199223

200224
}

spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,44 @@ public void testTokenTextSplitterBuilderWithAllFields() {
125125
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
126126
}
127127

128+
@Test
129+
public void testTokenTextSplitterWithCustomPunctuationMarks() {
130+
var contentFormatter1 = DefaultContentFormatter.defaultConfig();
131+
var contentFormatter2 = DefaultContentFormatter.defaultConfig();
132+
133+
assertThat(contentFormatter1).isNotSameAs(contentFormatter2);
134+
135+
var doc1 = new Document("Here, we set custom punctuation marks。?!. We just want to test it works or not?");
136+
doc1.setContentFormatter(contentFormatter1);
137+
138+
var doc2 = new Document("And more, we add protected method getLastPunctuationIndex in TokenTextSplitter class!"
139+
+ "The subclasses can override this method to achieve their own business logic。We just want to test it works or not?");
140+
doc2.setContentFormatter(contentFormatter2);
141+
142+
var tokenTextSplitter = TokenTextSplitter.builder()
143+
.withChunkSize(10)
144+
.withMinChunkSizeChars(5)
145+
.withMinChunkLengthToEmbed(3)
146+
.withMaxNumChunks(50)
147+
.withKeepSeparator(true)
148+
.withPunctuationMarks(List.of('。', '?', '!'))
149+
.build();
150+
151+
var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));
152+
153+
assertThat(chunks.size()).isEqualTo(7);
154+
155+
// Doc 1
156+
assertThat(chunks.get(0).getText()).isEqualTo("Here, we set custom punctuation marks。?!");
157+
assertThat(chunks.get(1).getText()).isEqualTo(". We just want to test it works or not");
158+
159+
// Doc 2
160+
assertThat(chunks.get(2).getText()).isEqualTo("And more, we add protected method getLastPunctuation");
161+
assertThat(chunks.get(3).getText()).isEqualTo("Index in TokenTextSplitter class!");
162+
assertThat(chunks.get(4).getText()).isEqualTo("The subclasses can override this method to achieve their own");
163+
assertThat(chunks.get(5).getText()).isEqualTo("business logic。");
164+
assertThat(chunks.get(6).getText()).isEqualTo("We just want to test it works or not?");
165+
166+
}
167+
128168
}

0 commit comments

Comments
 (0)