Skip to content

Commit 9e339b4

Browse files
committed
Provide detailed Exception message when token count exceeds max
Closes #4835
Signed-off-by: John Blum <[email protected]>
1 parent 374c09e commit 9e339b4

File tree

2 files changed: 42 additions and 3 deletions.

spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* @author Laura Trotta
5252
* @author Jihoon Kim
5353
* @author Yanming Zhou
54+
* @author John Blum
5455
* @since 1.0.0
5556
*/
5657
public class TokenCountBatchingStrategy implements BatchingStrategy {
@@ -148,8 +149,9 @@ public List<List<Document>> batch(List<Document> documents) {
148149
int tokenCount = this.tokenCountEstimator
149150
.estimate(document.getFormattedContent(this.contentFormatter, this.metadataMode));
150151
if (tokenCount > this.maxInputTokenCount) {
151-
throw new IllegalArgumentException(
152-
"Tokens in a single document exceeds the maximum number of allowed input tokens");
152+
String message = "Tokens [%d] in document with ID [%s] exceeds maximum number of allowed input tokens [%d]"
153+
.formatted(tokenCount, document.getId(), this.maxInputTokenCount);
154+
throw new IllegalArgumentException(message);
153155
}
154156
documentTokens.put(document, tokenCount);
155157
}

spring-ai-commons/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,30 @@
2222

2323
import org.junit.jupiter.api.Test;
2424

25+
import org.springframework.ai.document.ContentFormatter;
2526
import org.springframework.ai.document.Document;
27+
import org.springframework.ai.document.MetadataMode;
28+
import org.springframework.ai.tokenizer.TokenCountEstimator;
2629
import org.springframework.core.io.DefaultResourceLoader;
2730
import org.springframework.core.io.Resource;
2831

2932
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
3034
import static org.assertj.core.api.Assertions.assertThatThrownBy;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.ArgumentMatchers.anyString;
37+
import static org.mockito.ArgumentMatchers.eq;
38+
import static org.mockito.Mockito.doReturn;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.times;
41+
import static org.mockito.Mockito.verify;
42+
import static org.mockito.Mockito.verifyNoMoreInteractions;
3143

3244
/**
33-
* Basic unit test for {@link TokenCountBatchingStrategy}.
45+
* Basic unit tests for {@link TokenCountBatchingStrategy}.
3446
*
3547
* @author Soby Chacko
48+
* @author John Blum
3649
*/
3750
public class TokenCountBatchingStrategyTests {
3851

@@ -54,4 +67,28 @@ void batchEmbeddingWithLargeDocumentExceedsMaxTokenSize() throws IOException {
5467
.isInstanceOf(IllegalArgumentException.class);
5568
}
5669

70+
@Test
71+
void documentTokenCountExceedsConfiguredMaxTokenCount() {
72+
73+
Document mockDocument = mock(Document.class);
74+
ContentFormatter mockContentFormatter = mock(ContentFormatter.class);
75+
TokenCountEstimator mockTokenCountEstimator = mock(TokenCountEstimator.class);
76+
77+
doReturn("123abc").when(mockDocument).getId();
78+
doReturn(10).when(mockTokenCountEstimator).estimate(anyString());
79+
doReturn("test").when(mockDocument).getFormattedContent(any(), any());
80+
81+
TokenCountBatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(mockTokenCountEstimator, 9, 0.0d,
82+
mockContentFormatter, MetadataMode.EMBED);
83+
84+
assertThatIllegalArgumentException().isThrownBy(() -> batchingStrategy.batch(List.of(mockDocument)))
85+
.withMessage("Tokens [10] in document with ID [123abc] exceeds maximum number of allowed input tokens [9]")
86+
.withNoCause();
87+
88+
verify(mockDocument, times(1)).getId();
89+
verify(mockDocument, times(1)).getFormattedContent(eq(mockContentFormatter), eq(MetadataMode.EMBED));
90+
verify(mockTokenCountEstimator, times(1)).estimate(eq("test"));
91+
verifyNoMoreInteractions(mockDocument, mockTokenCountEstimator);
92+
}
93+
5794
}

0 commit comments

Comments
 (0)