Initial APM side for aws bedrock #4937

Merged · 37 commits · Jan 10, 2025
Commits
04f3e35
Initial APM side for aws bedrock
yahya-mouman Nov 25, 2024
ab53da2
Merge branch 'master' into yahya/add-bedrock-support
yahya-mouman Nov 26, 2024
a79fd29
add extract response tags
yahya-mouman Dec 2, 2024
b55cd39
add extract response tags
yahya-mouman Dec 2, 2024
2b24cc8
remove hook
yahya-mouman Dec 2, 2024
71c947b
Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime.js
yahya-mouman Dec 2, 2024
34fb1b3
added example test for invoke amazon
yahya-mouman Dec 9, 2024
f48f134
added example test for invoke amazon
yahya-mouman Dec 9, 2024
022e5fa
update test with todos
yahya-mouman Dec 9, 2024
5179c23
update test with todos
yahya-mouman Dec 9, 2024
154b81e
Drop underscore in name
yahya-mouman Dec 11, 2024
b3fdc39
Update packages/datadog-plugin-aws-sdk/test/bedrock.spec.js
yahya-mouman Dec 11, 2024
11a9730
Constants normalization
yahya-mouman Dec 11, 2024
c169a71
Merge remote-tracking branch 'origin/yahya/add-bedrock-support' into …
yahya-mouman Dec 11, 2024
5d84757
Add Mistral AI
yahya-mouman Dec 11, 2024
a052bc9
Add aws bedrock rec
yahya-mouman Dec 16, 2024
47f8cc5
remove file
yahya-mouman Dec 16, 2024
aab5329
added tests with mocked responses
yahya-mouman Dec 16, 2024
951bb75
added jamba support to AI21 lab
yahya-mouman Dec 16, 2024
69b1cfe
update bedrock version
yahya-mouman Dec 16, 2024
82524fb
Update tests
yahya-mouman Dec 18, 2024
61a8c9c
remove only
yahya-mouman Dec 18, 2024
da3d117
Update response extractions to only pick up first completion/generation
yahya-mouman Dec 19, 2024
2d7af6b
Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime.js
yahya-mouman Dec 20, 2024
82f5625
Merge branch 'master' into yahya/add-bedrock-support
yahya-mouman Dec 20, 2024
13eb6a5
Merge remote-tracking branch 'origin/yahya/add-bedrock-support' into …
yahya-mouman Dec 23, 2024
32899e2
Change from constants to a struct for model provider
yahya-mouman Dec 23, 2024
ba0a816
format
yahya-mouman Dec 23, 2024
c2f458b
switch case
yahya-mouman Dec 23, 2024
79cc526
Add classes for generations and requestParams
yahya-mouman Dec 24, 2024
46c2d49
Make constructors name object based. and stringify prompt if it's not…
yahya-mouman Dec 24, 2024
1b40ab8
stringify message if it's not a string
yahya-mouman Dec 24, 2024
a955185
es lint
yahya-mouman Dec 24, 2024
b610601
fix bad variable name
yahya-mouman Dec 24, 2024
7b810f9
add extra tags
yahya-mouman Dec 24, 2024
c54e983
camelCase and lint
yahya-mouman Dec 24, 2024
6220304
camelCase and lint
yahya-mouman Dec 24, 2024
3 changes: 2 additions & 1 deletion packages/datadog-instrumentations/src/aws-sdk.js
@@ -167,7 +167,8 @@ function getChannelSuffix (name) {
'sns',
'sqs',
'states',
'stepfunctions'
'stepfunctions',
'bedrock runtime'
].includes(name)
? name
: 'default'
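The hunk above adds 'bedrock runtime' to the list of service names that keep their own diagnostics-channel suffix instead of falling back to 'default'. A minimal standalone sketch of that logic (the real list in packages/datadog-instrumentations/src/aws-sdk.js contains more services than shown here):

```javascript
// Abbreviated sketch of getChannelSuffix after this change.
function getChannelSuffix (name) {
  return [
    'sns',
    'sqs',
    'states',
    'stepfunctions',
    'bedrock runtime'
  ].includes(name)
    ? name
    : 'default'
}

console.log(getChannelSuffix('bedrock runtime')) // 'bedrock runtime'
console.log(getChannelSuffix('some-other-service')) // 'default'
```

The plugin's `static get id ()` returning 'bedrock runtime' is what makes this lookup match.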
295 changes: 295 additions & 0 deletions packages/datadog-plugin-aws-sdk/src/services/bedrockruntime.js
@@ -0,0 +1,295 @@
'use strict'

const BaseAwsSdkPlugin = require('../base')
const log = require('../../../dd-trace/src/log')

const PROVIDER = {
AI21: 'AI21',
AMAZON: 'AMAZON',
ANTHROPIC: 'ANTHROPIC',
COHERE: 'COHERE',
META: 'META',
STABILITY: 'STABILITY',
MISTRAL: 'MISTRAL'
}

const enabledOperations = ['invokeModel']

class BedrockRuntime extends BaseAwsSdkPlugin {
static get id () { return 'bedrock runtime' }

isEnabled (request) {
const operation = request.operation
if (!enabledOperations.includes(operation)) {
return false
}

return super.isEnabled(request)
}

generateTags (params, operation, response) {
let tags = {}
let modelName = ''
let modelProvider = ''
const modelMeta = params.modelId.split('.')
if (modelMeta.length === 2) {
[modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
} else {
[, modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
}

const shouldSetChoiceIds = modelProvider === PROVIDER.COHERE && !modelName.includes('embed')

const requestParams = extractRequestParams(params, modelProvider)
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName, shouldSetChoiceIds)

tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)

return tags
}
}

class Generation {
constructor ({ message = '', finishReason = '', choiceId = '' } = {}) {
// stringify message as it could be a single generated message as well as a list of embeddings
this.message = typeof message === 'string' ? message : JSON.stringify(message) || ''
this.finishReason = finishReason || ''
this.choiceId = choiceId || undefined
}
}

class RequestParams {
constructor ({
prompt = '',
temperature = undefined,
topP = undefined,
maxTokens = undefined,
stopSequences = [],
inputType = '',
truncate = '',
stream = '',
n = undefined
} = {}) {
// TODO: set a truncation limit to prompt
// stringify prompt as it could be a single prompt as well as a list of message objects
this.prompt = typeof prompt === 'string' ? prompt : JSON.stringify(prompt) || ''
this.temperature = temperature !== undefined ? temperature : undefined
this.topP = topP !== undefined ? topP : undefined
this.maxTokens = maxTokens !== undefined ? maxTokens : undefined
[Review comment from a Collaborator on lines +78 to +80]

nit: I think these checks (and the one on line 85) can just be this.temperature = temperature, or whatever the name is, since we default it to undefined anyway. Let me know if I'm reading this wrong, though!

this.stopSequences = stopSequences || []
this.inputType = inputType || ''
this.truncate = truncate || ''
this.stream = stream || ''
this.n = n !== undefined ? n : undefined
}
}
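The reviewer's nit above can be checked in isolation: when a destructured parameter already defaults to undefined, the `x !== undefined ? x : undefined` guard is a no-op. A minimal sketch (hypothetical field name, not the plugin's actual code):

```javascript
// Both forms assign the same value: if `temperature` is omitted, the
// destructuring default already yields undefined, so the ternary is redundant.
function withGuard ({ temperature = undefined } = {}) {
  return temperature !== undefined ? temperature : undefined
}

function withoutGuard ({ temperature = undefined } = {}) {
  return temperature
}

console.log(withGuard({}) === withoutGuard({})) // true (both undefined)
console.log(withGuard({ temperature: 0.5 }) === withoutGuard({ temperature: 0.5 })) // true
```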

function extractRequestParams (params, provider) {
const requestBody = JSON.parse(params.body)
const modelId = params.modelId

switch (provider) {
case PROVIDER.AI21: {
let userPrompt = requestBody.prompt
if (modelId.includes('jamba')) {
for (const message of requestBody.messages) {
if (message.role === 'user') {
userPrompt = message.content // Return the content of the most recent user message
}
}
}
return new RequestParams({
prompt: userPrompt,
temperature: requestBody.temperature,
topP: requestBody.top_p,
maxTokens: requestBody.max_tokens,
stopSequences: requestBody.stop_sequences
})
}
case PROVIDER.AMAZON: {
if (modelId.includes('embed')) {
return new RequestParams({ prompt: requestBody.inputText })
}
const textGenerationConfig = requestBody.textGenerationConfig || {}
return new RequestParams({
prompt: requestBody.inputText,
temperature: textGenerationConfig.temperature,
topP: textGenerationConfig.topP,
maxTokens: textGenerationConfig.maxTokenCount,
stopSequences: textGenerationConfig.stopSequences
})
}
case PROVIDER.ANTHROPIC: {
const prompt = requestBody.prompt || requestBody.messages
return new RequestParams({
prompt,
temperature: requestBody.temperature,
topP: requestBody.top_p,
maxTokens: requestBody.max_tokens_to_sample,
stopSequences: requestBody.stop_sequences
})
}
case PROVIDER.COHERE: {
if (modelId.includes('embed')) {
return new RequestParams({
prompt: requestBody.texts,
inputType: requestBody.input_type,
truncate: requestBody.truncate
})
}
return new RequestParams({
prompt: requestBody.prompt,
temperature: requestBody.temperature,
topP: requestBody.p,
maxTokens: requestBody.max_tokens,
stopSequences: requestBody.stop_sequences,
stream: requestBody.stream,
n: requestBody.num_generations
})
}
case PROVIDER.META: {
return new RequestParams({
prompt: requestBody.prompt,
temperature: requestBody.temperature,
topP: requestBody.top_p,
maxTokens: requestBody.max_gen_len
})
}
case PROVIDER.MISTRAL: {
return new RequestParams({
prompt: requestBody.prompt,
temperature: requestBody.temperature,
topP: requestBody.top_p,
maxTokens: requestBody.max_tokens,
stopSequences: requestBody.stop,
topK: requestBody.top_k
})
}
case PROVIDER.STABILITY: {
return new RequestParams()
}
default: {
return new RequestParams()
}
}
}
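As a rough illustration of the switch above: an Amazon Titan text request nests its sampling parameters under textGenerationConfig, which the AMAZON branch flattens into a single RequestParams shape. A self-contained sketch with a hypothetical request body (not taken from the PR's tests):

```javascript
// Hypothetical Titan invokeModel params, shaped like the AMAZON branch expects.
const params = {
  modelId: 'amazon.titan-text-lite-v1',
  body: JSON.stringify({
    inputText: 'Say hello',
    textGenerationConfig: { temperature: 0.7, topP: 0.9, maxTokenCount: 50, stopSequences: [] }
  })
}

const requestBody = JSON.parse(params.body)
const cfg = requestBody.textGenerationConfig || {}
// Mirrors the fields passed to `new RequestParams({...})` in the AMAZON case:
const extracted = {
  prompt: requestBody.inputText,
  temperature: cfg.temperature,
  topP: cfg.topP,
  maxTokens: cfg.maxTokenCount,
  stopSequences: cfg.stopSequences
}
console.log(extracted.prompt) // 'Say hello'
console.log(extracted.maxTokens) // 50
```

The `|| {}` fallback matters: embed-style Titan bodies carry no textGenerationConfig at all, which is why the function checks `modelId.includes('embed')` first.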

function extractTextAndResponseReason (response, provider, modelName, shouldSetChoiceIds) {
const body = JSON.parse(Buffer.from(response.body).toString('utf8'))

try {
switch (provider) {
case PROVIDER.AI21: {
if (modelName.includes('jamba')) {
const generations = body.choices || []
if (generations.length > 0) {
const generation = generations[0]
return new Generation({
message: generation.message,
finishReason: generation.finish_reason,
choiceId: shouldSetChoiceIds ? generation.id : undefined
})
}
}
const completions = body.completions || []
if (completions.length > 0) {
const completion = completions[0]
return new Generation({
message: completion.data?.text,
finishReason: completion?.finishReason,
choiceId: shouldSetChoiceIds ? completion?.id : undefined
})
}
return new Generation()
}
case PROVIDER.AMAZON: {
if (modelName.includes('embed')) {
return new Generation({ message: body.embedding })
}
const results = body.results || []
if (results.length > 0) {
const result = results[0]
return new Generation({ message: result.outputText, finishReason: result.completionReason })
}
break
}
case PROVIDER.ANTHROPIC: {
return new Generation({ message: body.completion || body.content, finishReason: body.stop_reason })
}
case PROVIDER.COHERE: {
if (modelName.includes('embed')) {
const embeddings = body.embeddings || [[]]
if (embeddings.length > 0) {
return new Generation({ message: embeddings[0] })
}
}
const generations = body.generations || []
if (generations.length > 0) {
const generation = generations[0]
return new Generation({
message: generation.text,
finishReason: generation.finish_reason,
choiceId: shouldSetChoiceIds ? generation.id : undefined
})
}
break
}
case PROVIDER.META: {
return new Generation({ message: body.generation, finishReason: body.stop_reason })
}
case PROVIDER.MISTRAL: {
const mistralGenerations = body.outputs || []
if (mistralGenerations.length > 0) {
const generation = mistralGenerations[0]
return new Generation({ message: generation.text, finishReason: generation.stop_reason })
}
break
}
case PROVIDER.STABILITY: {
return new Generation()
}
default: {
return new Generation()
}
}
} catch (error) {
log.warn('Unable to extract text/finishReason from response body. Defaulting to empty text/finishReason.')
return new Generation()
}

return new Generation()
}
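The Bedrock response body arrives as raw bytes, which is why the function decodes it with Buffer before switching on provider. A small sketch with a mocked Amazon-style payload (hypothetical values, mirroring only the AMAZON branch):

```javascript
// Hypothetical mocked response: Bedrock returns the body as a byte buffer.
const response = {
  body: Buffer.from(JSON.stringify({
    results: [{ outputText: 'Hello!', completionReason: 'FINISH' }]
  }))
}

// Same decode step as extractTextAndResponseReason:
const body = JSON.parse(Buffer.from(response.body).toString('utf8'))
const result = (body.results || [])[0]
// The AMAZON branch reads only the first result's text and completion reason.
console.log(result.outputText) // 'Hello!'
console.log(result.completionReason) // 'FINISH'
```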

function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
const tags = {}

// add request tags
tags['resource.name'] = operation
tags['aws.bedrock.request.model'] = modelName
tags['aws.bedrock.request.model_provider'] = modelProvider
tags['aws.bedrock.request.prompt'] = requestParams.prompt
tags['aws.bedrock.request.temperature'] = requestParams.temperature
tags['aws.bedrock.request.top_p'] = requestParams.topP
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
tags['aws.bedrock.request.input_type'] = requestParams.inputType
tags['aws.bedrock.request.truncate'] = requestParams.truncate
tags['aws.bedrock.request.stream'] = requestParams.stream
tags['aws.bedrock.request.n'] = requestParams.n

// add response tags
if (modelName.includes('embed')) {
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
}
if (textAndResponseReason.choiceId) {
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
}
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason

return tags
}

module.exports = BedrockRuntime
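Tying the file together: generateTags first derives the provider and model name from the modelId, handling both the two-part provider.model form and the three-part prefixed form. A standalone sketch of that parsing (the modelId strings below are illustrative examples, not values asserted by the PR):

```javascript
// Mirrors generateTags' modelId parsing: two-part ids are provider.model,
// three-part ids are prefix.provider.model.
function parseModelId (modelId) {
  const parts = modelId.split('.')
  let provider, name
  if (parts.length === 2) {
    [provider, name] = parts
  } else {
    [, provider, name] = parts
  }
  return { modelProvider: provider.toUpperCase(), modelName: name }
}

console.log(parseModelId('anthropic.claude-v2'))
// { modelProvider: 'ANTHROPIC', modelName: 'claude-v2' }
console.log(parseModelId('us.meta.llama3-70b-instruct-v1'))
// { modelProvider: 'META', modelName: 'llama3-70b-instruct-v1' }
```

The uppercased provider is then matched against the PROVIDER map to choose the right request/response extraction branch, and both values end up on the span as aws.bedrock.request.model and aws.bedrock.request.model_provider.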
1 change: 1 addition & 0 deletions packages/datadog-plugin-aws-sdk/src/services/index.js
@@ -12,4 +12,5 @@ exports.sns = require('./sns')
exports.sqs = require('./sqs')
exports.states = require('./states')
exports.stepfunctions = require('./stepfunctions')
exports.bedrockruntime = require('./bedrockruntime')
exports.default = require('./default')