Skip to content

Commit ef73f51

Browse files
authored
fix(go/plugins/compat_oai): include original request in model response (#3620)
1 parent b2d63ad commit ef73f51

File tree

4 files changed

+93
-31
lines changed

4 files changed

+93
-31
lines changed

go/plugins/compat_oai/compat_oai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ func (o *OpenAICompatible) DefineModel(provider, id string, opts ai.ModelOptions
111111
generator := NewModelGenerator(o.client, modelName).WithMessages(input.Messages).WithConfig(input.Config).WithTools(input.Tools)
112112

113113
// Generate response
114-
resp, err := generator.Generate(ctx, cb)
114+
resp, err := generator.Generate(ctx, input, cb)
115115
if err != nil {
116116
return nil, err
117117
}

go/plugins/compat_oai/generate.go

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,11 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
8282

8383
am := openai.ChatCompletionAssistantMessageParam{}
8484
am.Content.OfString = param.NewOpt(content)
85-
toolCalls := convertToolCalls(msg.Content)
85+
toolCalls, err := convertToolCalls(msg.Content)
86+
if err != nil {
87+
g.err = err
88+
return g
89+
}
8690
if len(toolCalls) > 0 {
8791
am.ToolCalls = (toolCalls)
8892
}
@@ -100,10 +104,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
100104
toolCallID = p.ToolResponse.Name
101105
}
102106

103-
tm := openai.ToolMessage(
104-
anyToJSONString(p.ToolResponse.Output),
105-
toolCallID,
106-
)
107+
toolOutput, err := anyToJSONString(p.ToolResponse.Output)
108+
if err != nil {
109+
g.err = err
110+
return g
111+
}
112+
tm := openai.ToolMessage(toolOutput, toolCallID)
107113
oaiMessages = append(oaiMessages, tm)
108114
}
109115
case ai.RoleUser:
@@ -210,7 +216,7 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator {
210216
}
211217

212218
// Generate executes the generation request
213-
func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
219+
func (g *ModelGenerator) Generate(ctx context.Context, req *ai.ModelRequest, handleChunk func(context.Context, *ai.ModelResponseChunk) error) (*ai.ModelResponse, error) {
214220
// Check for any errors that occurred during building
215221
if g.err != nil {
216222
return nil, g.err
@@ -228,7 +234,7 @@ func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.
228234
if handleChunk != nil {
229235
return g.generateStream(ctx, handleChunk)
230236
}
231-
return g.generateComplete(ctx)
237+
return g.generateComplete(ctx, req)
232238
}
233239

234240
// concatenateContent concatenates text content into a single string
@@ -322,11 +328,19 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
322328
if choice.FinishReason == "tool_calls" && currentToolCall != nil {
323329
// parse accumulated arguments string
324330
for _, toolcall := range toolCallCollects {
325-
toolcall.toolCall.Input = jsonStringToMap(toolcall.args)
331+
args, err := jsonStringToMap(toolcall.args)
332+
if err != nil {
333+
return nil, fmt.Errorf("could not parse tool args: %w", err)
334+
}
335+
toolcall.toolCall.Input = args
326336
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(toolcall.toolCall))
327337
}
328338
if currentArguments != "" {
329-
currentToolCall.Input = jsonStringToMap(currentArguments)
339+
args, err := jsonStringToMap(currentArguments)
340+
if err != nil {
341+
return nil, fmt.Errorf("could not parse tool args: %w", err)
342+
}
343+
currentToolCall.Input = args
330344
}
331345
fullResponse.Message.Content = append(fullResponse.Message.Content, ai.NewToolRequestPart(currentToolCall))
332346
}
@@ -356,14 +370,14 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
356370
}
357371

358372
// generateComplete generates a complete model response
359-
func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelResponse, error) {
373+
func (g *ModelGenerator) generateComplete(ctx context.Context, req *ai.ModelRequest) (*ai.ModelResponse, error) {
360374
completion, err := g.client.Chat.Completions.New(ctx, *g.request)
361375
if err != nil {
362376
return nil, fmt.Errorf("failed to create completion: %w", err)
363377
}
364378

365379
resp := &ai.ModelResponse{
366-
Request: &ai.ModelRequest{},
380+
Request: req,
367381
Usage: &ai.GenerationUsage{
368382
InputTokens: int(completion.Usage.PromptTokens),
369383
OutputTokens: int(completion.Usage.CompletionTokens),
@@ -392,10 +406,14 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
392406
// handle tool calls
393407
var toolRequestParts []*ai.Part
394408
for _, toolCall := range choice.Message.ToolCalls {
409+
args, err := jsonStringToMap(toolCall.Function.Arguments)
410+
if err != nil {
411+
return nil, err
412+
}
395413
toolRequestParts = append(toolRequestParts, ai.NewToolRequestPart(&ai.ToolRequest{
396414
Ref: toolCall.ID,
397415
Name: toolCall.Function.Name,
398-
Input: jsonStringToMap(toolCall.Function.Arguments),
416+
Input: args,
399417
}))
400418
}
401419

@@ -412,50 +430,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
412430
return resp, nil
413431
}
414432

415-
func convertToolCalls(content []*ai.Part) []openai.ChatCompletionMessageToolCallParam {
433+
func convertToolCalls(content []*ai.Part) ([]openai.ChatCompletionMessageToolCallParam, error) {
416434
var toolCalls []openai.ChatCompletionMessageToolCallParam
417435
for _, p := range content {
418436
if !p.IsToolRequest() {
419437
continue
420438
}
421-
toolCall := convertToolCall(p)
422-
toolCalls = append(toolCalls, toolCall)
439+
toolCall, err := convertToolCall(p)
440+
if err != nil {
441+
return nil, err
442+
}
443+
toolCalls = append(toolCalls, *toolCall)
423444
}
424-
return toolCalls
445+
return toolCalls, nil
425446
}
426447

427-
func convertToolCall(part *ai.Part) openai.ChatCompletionMessageToolCallParam {
448+
func convertToolCall(part *ai.Part) (*openai.ChatCompletionMessageToolCallParam, error) {
428449
toolCallID := part.ToolRequest.Ref
429450
if toolCallID == "" {
430451
toolCallID = part.ToolRequest.Name
431452
}
432453

433-
param := openai.ChatCompletionMessageToolCallParam{
454+
param := &openai.ChatCompletionMessageToolCallParam{
434455
ID: (toolCallID),
435456
Function: (openai.ChatCompletionMessageToolCallFunctionParam{
436457
Name: (part.ToolRequest.Name),
437458
}),
438459
}
439460

461+
args, err := anyToJSONString(part.ToolRequest.Input)
462+
if err != nil {
463+
return nil, err
464+
}
440465
if part.ToolRequest.Input != nil {
441-
param.Function.Arguments = (anyToJSONString(part.ToolRequest.Input))
466+
param.Function.Arguments = args
442467
}
443468

444-
return param
469+
return param, nil
445470
}
446471

447-
func jsonStringToMap(jsonString string) map[string]any {
472+
func jsonStringToMap(jsonString string) (map[string]any, error) {
448473
var result map[string]any
449474
if err := json.Unmarshal([]byte(jsonString), &result); err != nil {
450-
panic(fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err))
475+
return nil, fmt.Errorf("unmarshal failed to parse json string %s: %w", jsonString, err)
451476
}
452-
return result
477+
return result, nil
453478
}
454479

455-
func anyToJSONString(data any) string {
480+
func anyToJSONString(data any) (string, error) {
456481
jsonBytes, err := json.Marshal(data)
457482
if err != nil {
458-
panic(fmt.Errorf("failed to marshal any to JSON string: data, %#v %w", data, err))
483+
return "", fmt.Errorf("failed to marshal any to JSON string: data, %#v %w", data, err)
459484
}
460-
return string(jsonBytes)
485+
return string(jsonBytes), nil
461486
}

go/plugins/compat_oai/generate_live_test.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ func TestGenerator_Complete(t *testing.T) {
6464
},
6565
},
6666
}
67+
req := &ai.ModelRequest{
68+
Messages: messages,
69+
}
6770

68-
resp, err := g.WithMessages(messages).Generate(context.Background(), nil)
71+
resp, err := g.WithMessages(messages).Generate(context.Background(), req, nil)
6972
if err != nil {
7073
t.Error(err)
7174
}
@@ -79,7 +82,6 @@ func TestGenerator_Complete(t *testing.T) {
7982

8083
func TestGenerator_Stream(t *testing.T) {
8184
g := setupTestClient(t)
82-
8385
messages := []*ai.Message{
8486
{
8587
Role: ai.RoleUser,
@@ -88,6 +90,9 @@ func TestGenerator_Stream(t *testing.T) {
8890
},
8991
},
9092
}
93+
req := &ai.ModelRequest{
94+
Messages: messages,
95+
}
9196

9297
var chunks []string
9398
handleChunk := func(ctx context.Context, chunk *ai.ModelResponseChunk) error {
@@ -97,7 +102,7 @@ func TestGenerator_Stream(t *testing.T) {
97102
return nil
98103
}
99104

100-
_, err := g.WithMessages(messages).Generate(context.Background(), handleChunk)
105+
_, err := g.WithMessages(messages).Generate(context.Background(), req, handleChunk)
101106
if err != nil {
102107
t.Error(err)
103108
}
@@ -229,11 +234,14 @@ func TestWithConfig(t *testing.T) {
229234
},
230235
},
231236
}
237+
req := &ai.ModelRequest{
238+
Messages: messages,
239+
}
232240

233241
for _, tt := range tests {
234242
t.Run(tt.name, func(t *testing.T) {
235243
generator := setupTestClient(t)
236-
result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), nil)
244+
result, err := generator.WithMessages(messages).WithConfig(tt.config).Generate(context.Background(), req, nil)
237245

238246
if tt.err != nil {
239247
assert.Error(t, err)

go/plugins/compat_oai/openai/openai_live_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,33 @@ func TestPlugin(t *testing.T) {
250250
}
251251
t.Logf("invalid config type error: %v", err)
252252
})
253+
254+
t.Run("check history", func(t *testing.T) {
255+
resp, err := genkit.Generate(ctx, g,
256+
ai.WithPrompt("Tell me a joke"))
257+
if err != nil {
258+
t.Fatal("got error: %w", err)
259+
}
260+
if resp.Request == nil {
261+
t.Fatal("unexpected nil pointer for request")
262+
}
263+
if len(resp.Request.Messages) == 0 {
264+
t.Fatal("expecting user messages in request")
265+
}
266+
resp, err = genkit.Generate(ctx, g,
267+
ai.WithMessages(resp.History()...),
268+
ai.WithPrompt("explain the joke that you just provided me"))
269+
if err != nil {
270+
t.Fatal("got error: %w", err)
271+
}
272+
userMsgCount := 0
273+
for _, m := range resp.History() {
274+
if m.Role == ai.RoleUser {
275+
userMsgCount += 1
276+
}
277+
}
278+
if userMsgCount != 2 {
279+
t.Fatalf("expecting 2 user messages, got: %d", userMsgCount)
280+
}
281+
})
253282
}

0 commit comments

Comments
 (0)