Skip to content

Commit f4ff581

Browse files
Stuart Loxtonchedim
authored andcommitted
[Enhancement] - Add requestMetadata to Bedrock converse request
- Derive requestMetadata from user metadata Signed-off-by: Stuart Loxton <[email protected]> Signed-off-by: Stuart Loxton <[email protected]>
1 parent a70500d commit f4ff581

File tree

5 files changed

+53
-7
lines changed

5 files changed

+53
-7
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockChatOptions.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ public class BedrockChatOptions implements ToolCallingChatOptions {
5454
@JsonProperty("presencePenalty")
5555
private Double presencePenalty;
5656

57+
@JsonIgnore
58+
private Map<String, String> requestParameters = new HashMap<>();
59+
5760
@JsonProperty("stopSequences")
5861
private List<String> stopSequences;
5962

@@ -88,6 +91,7 @@ public static BedrockChatOptions fromOptions(BedrockChatOptions fromOptions) {
8891
.frequencyPenalty(fromOptions.getFrequencyPenalty())
8992
.maxTokens(fromOptions.getMaxTokens())
9093
.presencePenalty(fromOptions.getPresencePenalty())
94+
.requestParameters(new HashMap<>(fromOptions.getRequestParameters()))
9195
.stopSequences(
9296
fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null)
9397
.temperature(fromOptions.getTemperature())
@@ -127,6 +131,14 @@ public void setMaxTokens(Integer maxTokens) {
127131
this.maxTokens = maxTokens;
128132
}
129133

134+
public Map<String, String> getRequestParameters() {
135+
return this.requestParameters;
136+
}
137+
138+
public void setRequestParameters(Map<String, String> requestParameters) {
139+
this.requestParameters = requestParameters;
140+
}
141+
130142
@Override
131143
public Double getPresencePenalty() {
132144
return this.presencePenalty;
@@ -242,6 +254,7 @@ public boolean equals(Object o) {
242254
return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty)
243255
&& Objects.equals(this.maxTokens, that.maxTokens)
244256
&& Objects.equals(this.presencePenalty, that.presencePenalty)
257+
&& Objects.equals(this.requestParameters, that.requestParameters)
245258
&& Objects.equals(this.stopSequences, that.stopSequences)
246259
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK)
247260
&& Objects.equals(this.topP, that.topP) && Objects.equals(this.toolCallbacks, that.toolCallbacks)
@@ -251,9 +264,9 @@ public boolean equals(Object o) {
251264

252265
@Override
253266
public int hashCode() {
254-
return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, this.stopSequences,
255-
this.temperature, this.topK, this.topP, this.toolCallbacks, this.toolNames, this.toolContext,
256-
this.internalToolExecutionEnabled);
267+
return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty,
268+
this.requestParameters, this.stopSequences, this.temperature, this.topK, this.topP, this.toolCallbacks,
269+
this.toolNames, this.toolContext, this.internalToolExecutionEnabled);
257270
}
258271

259272
public static class Builder {
@@ -280,6 +293,11 @@ public Builder presencePenalty(Double presencePenalty) {
280293
return this;
281294
}
282295

296+
public Builder requestParameters(Map<String, String> requestParameters) {
297+
this.options.requestParameters = requestParameters;
298+
return this;
299+
}
300+
283301
public Builder stopSequences(List<String> stopSequences) {
284302
this.options.stopSequences = stopSequences;
285303
return this;

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,17 @@ else if (message.getMessageType() == MessageType.TOOL) {
429429
Document additionalModelRequestFields = ConverseApiUtils
430430
.getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions());
431431

432+
Map<String, String> requestMetadata = ConverseApiUtils
433+
.getRequestMetadata(prompt.getUserMessage().getMetadata());
434+
432435
return ConverseRequest.builder()
433436
.modelId(updatedRuntimeOptions.getModel())
434437
.inferenceConfig(inferenceConfiguration)
435438
.messages(instructionMessages)
436439
.system(systemMessages)
437440
.additionalModelRequestFields(additionalModelRequestFields)
438441
.toolConfig(toolConfiguration)
442+
.requestMetadata(requestMetadata)
439443
.build();
440444
}
441445

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/api/ConverseApiUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,26 @@ else if (value instanceof Map mapValue) {
384384
}
385385
}
386386

387+
@SuppressWarnings("unchecked")
388+
public static Map<String, String> getRequestMetadata(Map<String, Object> metadata) {
389+
390+
if (metadata.isEmpty()) {
391+
return Map.of();
392+
}
393+
394+
Map<String, String> result = new HashMap<>();
395+
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
396+
String key = entry.getKey();
397+
Object value = entry.getValue();
398+
399+
if (key != null && value != null) {
400+
result.put(key, value.toString());
401+
}
402+
}
403+
404+
return result;
405+
}
406+
387407
private static Document convertMapToDocument(Map<String, Object> value) {
388408
Map<String, Document> attr = value.entrySet()
389409
.stream()

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockChatOptionsTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,18 @@ void testBuilderWithAllFields() {
3737
.frequencyPenalty(0.0)
3838
.maxTokens(100)
3939
.presencePenalty(0.0)
40+
.requestParameters(Map.of("requestId", "1234"))
4041
.stopSequences(List.of("stop1", "stop2"))
4142
.temperature(0.7)
4243
.topP(0.8)
4344
.topK(50)
4445
.build();
4546

4647
assertThat(options)
47-
.extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature",
48-
"topP", "topK")
49-
.containsExactly("test-model", 0.0, 100, 0.0, List.of("stop1", "stop2"), 0.7, 0.8, 50);
48+
.extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "requestParameters",
49+
"stopSequences", "temperature", "topP", "topK")
50+
.containsExactly("test-model", 0.0, 100, 0.0, Map.of("requestId", "1234"), List.of("stop1", "stop2"), 0.7,
51+
0.8, 50);
5052
}
5153

5254
@Test

models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ void call() {
7070
.system(s -> s.text(this.systemTextResource)
7171
.param("name", "Bob")
7272
.param("voice", "pirate"))
73-
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
73+
.user(u -> u.text("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
74+
.metadata("requestId", "12345")
75+
)
7476
.call()
7577
.chatResponse();
7678
// @formatter:on

0 commit comments

Comments
 (0)