Skip to content

Commit

Permalink
aws/bedrock: review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rockwotj committed Aug 21, 2024
1 parent 629e97e commit 10d3009
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions internal/impl/aws/processor_bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ import (
)

const (
bedFieldModel = "model"
bedFieldUserPrompt = "prompt"
bedFieldSystemPrompt = "system_prompt"
bedFieldMaxTokens = "max_tokens"
bedFieldTemp = "stop"
bedFieldStop = "temperature"
bedFieldTopP = "top_p"
bedpFieldModel = "model"
bedpFieldUserPrompt = "prompt"
bedpFieldSystemPrompt = "system_prompt"
bedpFieldMaxTokens = "max_tokens"
bedpFieldTemp = "stop"
bedpFieldStop = "temperature"
bedpFieldTopP = "top_p"
)

func init() {
Expand All @@ -45,32 +45,32 @@ For more information, see the https://docs.aws.amazon.com/bedrock/latest/usergui
Categories("AI").
Version("4.34.0").
Fields(config.SessionFields()...).
Field(service.NewStringField(bedFieldModel).
Field(service.NewStringField(bedpFieldModel).
Examples("amazon.titan-text-express-v1", "anthropic.claude-3-5-sonnet-20240620-v1:0", "cohere.command-text-v14", "meta.llama3-1-70b-instruct-v1:0", "mistral.mistral-large-2402-v1:0").
Description("The model ID to use. For a full list see the https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html[AWS Bedrock documentation^].")).
Field(service.NewStringField(bedFieldUserPrompt).
Field(service.NewStringField(bedpFieldUserPrompt).
Description("The prompt you want to generate a response for. By default, the processor submits the entire payload as a string.").
Optional()).
Field(service.NewStringField(bedFieldSystemPrompt).
Field(service.NewStringField(bedpFieldSystemPrompt).
Optional().
Description("The system prompt to submit to the AWS Bedrock LLM.")).
Field(service.NewIntField(bedFieldMaxTokens).
Field(service.NewIntField(bedpFieldMaxTokens).
Optional().
Description("The maximum number of tokens to allow in the generated response.")).
Field(service.NewFloatField(bedFieldTemp).
Description("The maximum number of tokens to allow in the generated response.").
LintRule(`root = this < 1 { ["field must be greater than or equal to 1"] }`)).
Field(service.NewFloatField(bedpFieldTemp).
Optional().
Description("The likelihood of the model selecting higher-probability options while generating a response. A lower value makes the model omre likely to choose higher-probability options, while a higher value makes the model more likely to choose lower-probability options.").
LintRule(`root = if this < 0 || this > 1 { ["field must be between 0.0-1.0"] }`)).
Field(service.NewStringListField(bedFieldStop).
Field(service.NewStringListField(bedpFieldStop).
Optional().
Advanced().
Description("A list of stop sequences. A stop sequence is a sequence of characters that causes the model to stop generating the response.")).
Field(service.NewFloatField(bedFieldTopP).
Field(service.NewFloatField(bedpFieldTopP).
Optional().
Advanced().
Description("The percentage of most-likely candidates that the model considers for the next token. For example, if you choose a value of 0.8, the model selects from the top 80% of the probability distribution of tokens that could be next in the sequence. ").
LintRule(`root = if this < 0 || this > 1 { ["field must be between 0.0-1.0"] }`))

}

func newBedrockProcessor(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) {
Expand All @@ -79,53 +79,53 @@ func newBedrockProcessor(conf *service.ParsedConfig, mgr *service.Resources) (se
return nil, err
}
client := bedrockruntime.NewFromConfig(aconf)
model, err := conf.FieldString(bedFieldModel)
model, err := conf.FieldString(bedpFieldModel)
if err != nil {
return nil, err
}
p := &bedrockProcessor{
client: client,
model: model,
}
if conf.Contains(bedFieldUserPrompt) {
pf, err := conf.FieldInterpolatedString(bedFieldUserPrompt)
if conf.Contains(bedpFieldUserPrompt) {
pf, err := conf.FieldInterpolatedString(bedpFieldUserPrompt)
if err != nil {
return nil, err
}
p.userPrompt = pf
}
if conf.Contains(bedFieldSystemPrompt) {
pf, err := conf.FieldInterpolatedString(bedFieldSystemPrompt)
if conf.Contains(bedpFieldSystemPrompt) {
pf, err := conf.FieldInterpolatedString(bedpFieldSystemPrompt)
if err != nil {
return nil, err
}
p.systemPrompt = pf
}
if conf.Contains(bedFieldMaxTokens) {
v, err := conf.FieldInt(bedFieldMaxTokens)
if conf.Contains(bedpFieldMaxTokens) {
v, err := conf.FieldInt(bedpFieldMaxTokens)
if err != nil {
return nil, err
}
mt := int32(v)
p.maxTokens = &mt
}
if conf.Contains(bedFieldTemp) {
v, err := conf.FieldFloat(bedFieldTemp)
if conf.Contains(bedpFieldTemp) {
v, err := conf.FieldFloat(bedpFieldTemp)
if err != nil {
return nil, err
}
t := float32(v)
p.temp = &t
}
if conf.Contains(bedFieldStop) {
stop, err := conf.FieldStringList(bedFieldStop)
if conf.Contains(bedpFieldStop) {
stop, err := conf.FieldStringList(bedpFieldStop)
if err != nil {
return nil, err
}
p.stop = stop
}
if conf.Contains(bedFieldTopP) {
v, err := conf.FieldFloat(bedFieldTopP)
if conf.Contains(bedpFieldTopP) {
v, err := conf.FieldFloat(bedpFieldTopP)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -174,7 +174,7 @@ func (b *bedrockProcessor) Process(ctx context.Context, msg *service.Message) (s
if b.systemPrompt != nil {
prompt, err := b.systemPrompt.TryString(msg)
if err != nil {
return nil, fmt.Errorf("unable to interpolate `%s`: %w", bedFieldSystemPrompt, err)
return nil, fmt.Errorf("unable to interpolate `%s`: %w", bedpFieldSystemPrompt, err)
}
input.System = []bedrocktypes.SystemContentBlock{
&bedrocktypes.SystemContentBlockMemberText{Value: prompt},
Expand Down

0 comments on commit 10d3009

Please sign in to comment.