Skip to content

Commit ea32d13

Browse files
committed
Provide detailed Exception message when token count exceeds max
Closes #x
1 parent 374c09e commit ea32d13

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* @author Laura Trotta
5252
* @author Jihoon Kim
5353
* @author Yanming Zhou
54+
* @author John Blum
5455
* @since 1.0.0
5556
*/
5657
public class TokenCountBatchingStrategy implements BatchingStrategy {
@@ -148,8 +149,9 @@ public List<List<Document>> batch(List<Document> documents) {
148149
int tokenCount = this.tokenCountEstimator
149150
.estimate(document.getFormattedContent(this.contentFormatter, this.metadataMode));
150151
if (tokenCount > this.maxInputTokenCount) {
151-
throw new IllegalArgumentException(
152-
"Tokens in a single document exceeds the maximum number of allowed input tokens");
152+
String message = "Tokens in a single document [%d] exceeds the maximum number of allowed input tokens [%d]"
153+
.formatted(tokenCount, this.maxInputTokenCount);
154+
throw new IllegalArgumentException(message);
153155
}
154156
documentTokens.put(document, tokenCount);
155157
}

spring-ai-commons/src/test/java/org/springframework/ai/embedding/TokenCountBatchingStrategyTests.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,31 @@
2222

2323
import org.junit.jupiter.api.Test;
2424

25+
import org.springframework.ai.document.ContentFormatter;
2526
import org.springframework.ai.document.Document;
27+
import org.springframework.ai.document.MetadataMode;
28+
import org.springframework.ai.tokenizer.TokenCountEstimator;
2629
import org.springframework.core.io.DefaultResourceLoader;
2730
import org.springframework.core.io.Resource;
2831

2932
import static org.assertj.core.api.Assertions.assertThat;
33+
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
3034
import static org.assertj.core.api.Assertions.assertThatThrownBy;
35+
import static org.mockito.ArgumentMatchers.any;
36+
import static org.mockito.ArgumentMatchers.anyString;
37+
import static org.mockito.ArgumentMatchers.eq;
38+
import static org.mockito.Mockito.doReturn;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.times;
41+
import static org.mockito.Mockito.verify;
42+
import static org.mockito.Mockito.verifyNoInteractions;
43+
import static org.mockito.Mockito.verifyNoMoreInteractions;
3144

3245
/**
33-
* Basic unit test for {@link TokenCountBatchingStrategy}.
46+
* Basic unit tests for {@link TokenCountBatchingStrategy}.
3447
*
3548
* @author Soby Chacko
49+
* @author John Blum
3650
*/
3751
public class TokenCountBatchingStrategyTests {
3852

@@ -54,4 +68,27 @@ void batchEmbeddingWithLargeDocumentExceedsMaxTokenSize() throws IOException {
5468
.isInstanceOf(IllegalArgumentException.class);
5569
}
5670

71+
@Test
72+
void documentTokenCountExceedsConfiguredMaxTokenCount() {
73+
74+
Document mockDocument = mock(Document.class);
75+
ContentFormatter mockContentFormatter = mock(ContentFormatter.class);
76+
TokenCountEstimator mockTokenCountEstimator = mock(TokenCountEstimator.class);
77+
78+
doReturn(10).when(mockTokenCountEstimator).estimate(anyString());
79+
doReturn("test").when(mockDocument).getFormattedContent(any(), any());
80+
81+
TokenCountBatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(mockTokenCountEstimator, 9, 0.0d,
82+
mockContentFormatter, MetadataMode.EMBED);
83+
84+
assertThatIllegalArgumentException().isThrownBy(() -> batchingStrategy.batch(List.of(mockDocument)))
85+
.withMessage("Tokens in a single document [10] exceeds the maximum number of allowed input tokens [9]")
86+
.withNoCause();
87+
88+
verify(mockDocument, times(1)).getFormattedContent(eq(mockContentFormatter), eq(MetadataMode.EMBED));
89+
verify(mockTokenCountEstimator, times(1)).estimate(eq("test"));
90+
verifyNoMoreInteractions(mockDocument, mockTokenCountEstimator);
91+
verifyNoInteractions(mockContentFormatter);
92+
}
93+
5794
}

0 commit comments

Comments
 (0)