Skip to content

Commit dbdc5fc

Browse files
authored
fix(go/ai): handle tools passed in Prompt.Execute() (#3601)
1 parent 5259890 commit dbdc5fc

File tree

4 files changed

+115
-19
lines changed

4 files changed

+115
-19
lines changed

go/ai/generate.go

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -383,23 +383,9 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
383383
modelName = genOpts.Model.Name()
384384
}
385385

386-
var dynamicTools []Tool
387-
tools := make([]string, len(genOpts.Tools))
388-
toolNames := make(map[string]bool)
389-
for i, toolRef := range genOpts.Tools {
390-
name := toolRef.Name()
391-
// Redundant duplicate tool check with GenerateWithRequest otherwise we will panic when we register the dynamic tools.
392-
if toolNames[name] {
393-
return nil, core.NewError(core.INVALID_ARGUMENT, "ai.Generate: duplicate tool %q", name)
394-
}
395-
toolNames[name] = true
396-
tools[i] = name
397-
// Dynamic tools wouldn't have been registered by this point.
398-
if LookupTool(r, name) == nil {
399-
if tool, ok := toolRef.(Tool); ok {
400-
dynamicTools = append(dynamicTools, tool)
401-
}
402-
}
386+
toolNames, dynamicTools, err := resolveUniqueTools(r, genOpts.Tools)
387+
if err != nil {
388+
return nil, err
403389
}
404390

405391
if len(dynamicTools) > 0 {
@@ -477,7 +463,7 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
477463
actionOpts := &GenerateActionOptions{
478464
Model: modelName,
479465
Messages: messages,
480-
Tools: tools,
466+
Tools: toolNames,
481467
MaxTurns: genOpts.MaxTurns,
482468
Config: genOpts.Config,
483469
ToolChoice: genOpts.ToolChoice,

go/ai/prompt.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,24 +144,31 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
144144
if modelRef, ok := execOpts.Model.(ModelRef); ok && execOpts.Config == nil {
145145
execOpts.Config = modelRef.Config()
146146
}
147+
147148
if execOpts.Config != nil {
148149
actionOpts.Config = execOpts.Config
149150
}
151+
150152
if len(execOpts.Documents) > 0 {
151153
actionOpts.Docs = execOpts.Documents
152154
}
155+
153156
if execOpts.ToolChoice != "" {
154157
actionOpts.ToolChoice = execOpts.ToolChoice
155158
}
159+
156160
if execOpts.Model != nil {
157161
actionOpts.Model = execOpts.Model.Name()
158162
}
163+
159164
if execOpts.MaxTurns != 0 {
160165
actionOpts.MaxTurns = execOpts.MaxTurns
161166
}
167+
162168
if execOpts.ReturnToolRequests != nil {
163169
actionOpts.ReturnToolRequests = *execOpts.ReturnToolRequests
164170
}
171+
165172
if execOpts.MessagesFn != nil {
166173
m, err := buildVariables(execOpts.Input)
167174
if err != nil {
@@ -180,7 +187,31 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod
180187
}
181188
}
182189

183-
return GenerateWithRequest(ctx, p.registry, actionOpts, execOpts.Middleware, execOpts.Stream)
190+
toolRefs := execOpts.Tools
191+
if len(toolRefs) == 0 {
192+
toolRefs = make([]ToolRef, 0, len(actionOpts.Tools))
193+
for _, toolName := range actionOpts.Tools {
194+
toolRefs = append(toolRefs, ToolName(toolName))
195+
}
196+
}
197+
198+
toolNames, newTools, err := resolveUniqueTools(p.registry, toolRefs)
199+
if err != nil {
200+
return nil, err
201+
}
202+
actionOpts.Tools = toolNames
203+
204+
r := p.registry
205+
if len(newTools) > 0 {
206+
if !r.IsChild() {
207+
r = r.NewChild()
208+
}
209+
for _, t := range newTools {
210+
t.Register(p.registry)
211+
}
212+
}
213+
214+
return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream)
184215
}
185216

186217
// Render renders the prompt template based on user input.

go/ai/prompt_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,61 @@ func TestValidPrompt(t *testing.T) {
536536
},
537537
},
538538
},
539+
{
540+
name: "execute with tools overriding prompt-level tools",
541+
model: model,
542+
config: &GenerationCommonConfig{Temperature: 11},
543+
inputType: HelloPromptInput{},
544+
systemText: "say hello",
545+
promptText: "my name is foo",
546+
tools: []ToolRef{testTool(reg, "promptTool")},
547+
input: HelloPromptInput{Name: "foo"},
548+
executeOptions: []PromptExecuteOption{
549+
WithInput(HelloPromptInput{Name: "foo"}),
550+
WithTools(testTool(reg, "executeOverrideTool")),
551+
},
552+
wantTextOutput: "Echo: system: tool: say hello; my name is foo; ; Bar; ; config: {\n \"temperature\": 11\n}; context: null",
553+
wantGenerated: &ModelRequest{
554+
Config: &GenerationCommonConfig{
555+
Temperature: 11,
556+
},
557+
Output: &ModelOutputConfig{
558+
ContentType: "text/plain",
559+
},
560+
ToolChoice: "required",
561+
Messages: []*Message{
562+
{
563+
Role: RoleSystem,
564+
Content: []*Part{NewTextPart("say hello")},
565+
},
566+
{
567+
Role: RoleUser,
568+
Content: []*Part{NewTextPart("my name is foo")},
569+
},
570+
{
571+
Role: RoleModel,
572+
Content: []*Part{NewToolRequestPart(&ToolRequest{Name: "executeOverrideTool", Input: map[string]any{"Test": "Bar"}})},
573+
},
574+
{
575+
Role: RoleTool,
576+
Content: []*Part{NewToolResponsePart(&ToolResponse{Output: "Bar"})},
577+
},
578+
},
579+
Tools: []*ToolDefinition{
580+
{
581+
Name: "executeOverrideTool",
582+
Description: "use when need to execute a test",
583+
InputSchema: map[string]any{
584+
"additionalProperties": bool(false),
585+
"properties": map[string]any{"Test": map[string]any{"type": string("string")}},
586+
"required": []any{string("Test")},
587+
"type": string("object"),
588+
},
589+
OutputSchema: map[string]any{"type": string("string")},
590+
},
591+
},
592+
},
593+
},
539594
}
540595

541596
cmpPart := func(a, b *Part) bool {

go/ai/tools.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,27 @@ func (t *tool) Restart(p *Part, opts *RestartOptions) *Part {
300300

301301
return newToolReq
302302
}
303+
304+
// resolveUniqueTools resolves the list of tool refs to a list of all tool names and new tools that must be registered.
305+
// Returns an error if there are tool refs with duplicate names.
306+
func resolveUniqueTools(r api.Registry, toolRefs []ToolRef) (toolNames []string, newTools []Tool, err error) {
307+
toolMap := make(map[string]bool)
308+
309+
for _, toolRef := range toolRefs {
310+
name := toolRef.Name()
311+
312+
if toolMap[name] {
313+
return nil, nil, core.NewError(core.INVALID_ARGUMENT, "duplicate tool %q", name)
314+
}
315+
toolMap[name] = true
316+
toolNames = append(toolNames, name)
317+
318+
if LookupTool(r, name) == nil {
319+
if tool, ok := toolRef.(Tool); ok {
320+
newTools = append(newTools, tool)
321+
}
322+
}
323+
}
324+
325+
return toolNames, newTools, nil
326+
}

0 commit comments

Comments
 (0)