@@ -82,7 +82,11 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
82
82
83
83
am := openai.ChatCompletionAssistantMessageParam {}
84
84
am .Content .OfString = param .NewOpt (content )
85
- toolCalls := convertToolCalls (msg .Content )
85
+ toolCalls , err := convertToolCalls (msg .Content )
86
+ if err != nil {
87
+ g .err = err
88
+ return g
89
+ }
86
90
if len (toolCalls ) > 0 {
87
91
am .ToolCalls = (toolCalls )
88
92
}
@@ -100,10 +104,12 @@ func (g *ModelGenerator) WithMessages(messages []*ai.Message) *ModelGenerator {
100
104
toolCallID = p .ToolResponse .Name
101
105
}
102
106
103
- tm := openai .ToolMessage (
104
- anyToJSONString (p .ToolResponse .Output ),
105
- toolCallID ,
106
- )
107
+ toolOutput , err := anyToJSONString (p .ToolResponse .Output )
108
+ if err != nil {
109
+ g .err = err
110
+ return g
111
+ }
112
+ tm := openai .ToolMessage (toolOutput , toolCallID )
107
113
oaiMessages = append (oaiMessages , tm )
108
114
}
109
115
case ai .RoleUser :
@@ -210,7 +216,7 @@ func (g *ModelGenerator) WithTools(tools []*ai.ToolDefinition) *ModelGenerator {
210
216
}
211
217
212
218
// Generate executes the generation request
213
- func (g * ModelGenerator ) Generate (ctx context.Context , handleChunk func (context.Context , * ai.ModelResponseChunk ) error ) (* ai.ModelResponse , error ) {
219
+ func (g * ModelGenerator ) Generate (ctx context.Context , req * ai. ModelRequest , handleChunk func (context.Context , * ai.ModelResponseChunk ) error ) (* ai.ModelResponse , error ) {
214
220
// Check for any errors that occurred during building
215
221
if g .err != nil {
216
222
return nil , g .err
@@ -228,7 +234,7 @@ func (g *ModelGenerator) Generate(ctx context.Context, handleChunk func(context.
228
234
if handleChunk != nil {
229
235
return g .generateStream (ctx , handleChunk )
230
236
}
231
- return g .generateComplete (ctx )
237
+ return g .generateComplete (ctx , req )
232
238
}
233
239
234
240
// concatenateContent concatenates text content into a single string
@@ -322,11 +328,19 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
322
328
if choice .FinishReason == "tool_calls" && currentToolCall != nil {
323
329
// parse accumulated arguments string
324
330
for _ , toolcall := range toolCallCollects {
325
- toolcall .toolCall .Input = jsonStringToMap (toolcall .args )
331
+ args , err := jsonStringToMap (toolcall .args )
332
+ if err != nil {
333
+ return nil , fmt .Errorf ("could not parse tool args: %w" , err )
334
+ }
335
+ toolcall .toolCall .Input = args
326
336
fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (toolcall .toolCall ))
327
337
}
328
338
if currentArguments != "" {
329
- currentToolCall .Input = jsonStringToMap (currentArguments )
339
+ args , err := jsonStringToMap (currentArguments )
340
+ if err != nil {
341
+ return nil , fmt .Errorf ("could not parse tool args: %w" , err )
342
+ }
343
+ currentToolCall .Input = args
330
344
}
331
345
fullResponse .Message .Content = append (fullResponse .Message .Content , ai .NewToolRequestPart (currentToolCall ))
332
346
}
@@ -356,14 +370,14 @@ func (g *ModelGenerator) generateStream(ctx context.Context, handleChunk func(co
356
370
}
357
371
358
372
// generateComplete generates a complete model response
359
- func (g * ModelGenerator ) generateComplete (ctx context.Context ) (* ai.ModelResponse , error ) {
373
+ func (g * ModelGenerator ) generateComplete (ctx context.Context , req * ai. ModelRequest ) (* ai.ModelResponse , error ) {
360
374
completion , err := g .client .Chat .Completions .New (ctx , * g .request )
361
375
if err != nil {
362
376
return nil , fmt .Errorf ("failed to create completion: %w" , err )
363
377
}
364
378
365
379
resp := & ai.ModelResponse {
366
- Request : & ai. ModelRequest {} ,
380
+ Request : req ,
367
381
Usage : & ai.GenerationUsage {
368
382
InputTokens : int (completion .Usage .PromptTokens ),
369
383
OutputTokens : int (completion .Usage .CompletionTokens ),
@@ -392,10 +406,14 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
392
406
// handle tool calls
393
407
var toolRequestParts []* ai.Part
394
408
for _ , toolCall := range choice .Message .ToolCalls {
409
+ args , err := jsonStringToMap (toolCall .Function .Arguments )
410
+ if err != nil {
411
+ return nil , err
412
+ }
395
413
toolRequestParts = append (toolRequestParts , ai .NewToolRequestPart (& ai.ToolRequest {
396
414
Ref : toolCall .ID ,
397
415
Name : toolCall .Function .Name ,
398
- Input : jsonStringToMap ( toolCall . Function . Arguments ) ,
416
+ Input : args ,
399
417
}))
400
418
}
401
419
@@ -412,50 +430,57 @@ func (g *ModelGenerator) generateComplete(ctx context.Context) (*ai.ModelRespons
412
430
return resp , nil
413
431
}
414
432
415
- func convertToolCalls (content []* ai.Part ) []openai.ChatCompletionMessageToolCallParam {
433
+ func convertToolCalls (content []* ai.Part ) ( []openai.ChatCompletionMessageToolCallParam , error ) {
416
434
var toolCalls []openai.ChatCompletionMessageToolCallParam
417
435
for _ , p := range content {
418
436
if ! p .IsToolRequest () {
419
437
continue
420
438
}
421
- toolCall := convertToolCall (p )
422
- toolCalls = append (toolCalls , toolCall )
439
+ toolCall , err := convertToolCall (p )
440
+ if err != nil {
441
+ return nil , err
442
+ }
443
+ toolCalls = append (toolCalls , * toolCall )
423
444
}
424
- return toolCalls
445
+ return toolCalls , nil
425
446
}
426
447
427
- func convertToolCall (part * ai.Part ) openai.ChatCompletionMessageToolCallParam {
448
+ func convertToolCall (part * ai.Part ) ( * openai.ChatCompletionMessageToolCallParam , error ) {
428
449
toolCallID := part .ToolRequest .Ref
429
450
if toolCallID == "" {
430
451
toolCallID = part .ToolRequest .Name
431
452
}
432
453
433
- param := openai.ChatCompletionMessageToolCallParam {
454
+ param := & openai.ChatCompletionMessageToolCallParam {
434
455
ID : (toolCallID ),
435
456
Function : (openai.ChatCompletionMessageToolCallFunctionParam {
436
457
Name : (part .ToolRequest .Name ),
437
458
}),
438
459
}
439
460
461
+ args , err := anyToJSONString (part .ToolRequest .Input )
462
+ if err != nil {
463
+ return nil , err
464
+ }
440
465
if part .ToolRequest .Input != nil {
441
- param .Function .Arguments = ( anyToJSONString ( part . ToolRequest . Input ))
466
+ param .Function .Arguments = args
442
467
}
443
468
444
- return param
469
+ return param , nil
445
470
}
446
471
447
- func jsonStringToMap (jsonString string ) map [string ]any {
472
+ func jsonStringToMap (jsonString string ) ( map [string ]any , error ) {
448
473
var result map [string ]any
449
474
if err := json .Unmarshal ([]byte (jsonString ), & result ); err != nil {
450
- panic ( fmt .Errorf ("unmarshal failed to parse json string %s: %w" , jsonString , err ) )
475
+ return nil , fmt .Errorf ("unmarshal failed to parse json string %s: %w" , jsonString , err )
451
476
}
452
- return result
477
+ return result , nil
453
478
}
454
479
455
- func anyToJSONString (data any ) string {
480
+ func anyToJSONString (data any ) ( string , error ) {
456
481
jsonBytes , err := json .Marshal (data )
457
482
if err != nil {
458
- panic ( fmt .Errorf ("failed to marshal any to JSON string: data, %#v %w" , data , err ) )
483
+ return "" , fmt .Errorf ("failed to marshal any to JSON string: data, %#v %w" , data , err )
459
484
}
460
- return string (jsonBytes )
485
+ return string (jsonBytes ), nil
461
486
}
0 commit comments