Skip to content

Commit e98e962

Browse files
committed
Automatically expand tool groups in FICC
1 parent 8f63547 commit e98e962

File tree

2 files changed

+77
-9
lines changed

2 files changed

+77
-9
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ public override async Task<ChatResponse> GetResponseAsync(
284284
bool lastIterationHadConversationId = false; // whether the last iteration's response had a ConversationId set
285285
int consecutiveErrorCount = 0;
286286

287-
(Dictionary<string, AITool>? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name
287+
(Dictionary<string, AITool>? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name
288288

289289
if (HasAnyApprovalContent(originalMessages))
290290
{
@@ -424,7 +424,7 @@ public override async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseA
424424
List<ChatResponseUpdate> updates = []; // updates from the current response
425425
int consecutiveErrorCount = 0;
426426

427-
(Dictionary<string, AITool>? toolMap, bool anyToolsRequireApproval) = CreateToolsMap(AdditionalTools, options?.Tools); // all available tools, indexed by name
427+
(Dictionary<string, AITool>? toolMap, bool anyToolsRequireApproval) = await CreateToolsMapAsync([AdditionalTools, options?.Tools], cancellationToken); // all available tools, indexed by name
428428

429429
// This is a synthetic ID since we're generating the tool messages instead of getting them from
430430
// the underlying provider. When emitting the streamed chunks, it's perfectly valid for us to
@@ -728,26 +728,51 @@ internal static void FixupHistories(
728728
/// The lists of tools to combine into a single dictionary. Tools from later lists are preferred
729729
/// over tools from earlier lists if they have the same name.
730730
/// </param>
731-
private static (Dictionary<string, AITool>? ToolMap, bool AnyRequireApproval) CreateToolsMap(params ReadOnlySpan<IList<AITool>?> toolLists)
731+
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
732+
private static async ValueTask<(Dictionary<string, AITool>? ToolMap, bool AnyRequireApproval)> CreateToolsMapAsync(IList<AITool>?[] toolLists, CancellationToken cancellationToken)
732733
{
733734
Dictionary<string, AITool>? map = null;
734735
bool anyRequireApproval = false;
735736

736737
foreach (var toolList in toolLists)
737738
{
738-
if (toolList?.Count is int count && count > 0)
739+
if (toolList is not null)
740+
{
741+
map ??= [];
742+
var anyInListRequireApproval = await AddToolListAsync(map, toolList, cancellationToken).ConfigureAwait(false);
743+
anyRequireApproval |= anyInListRequireApproval;
744+
}
745+
}
746+
747+
return (map, anyRequireApproval);
748+
749+
static async ValueTask<bool> AddToolListAsync(Dictionary<string, AITool> map, IEnumerable<AITool> tools, CancellationToken cancellationToken)
750+
{
751+
#if NET
752+
if (tools.TryGetNonEnumeratedCount(out var count) && count == 0)
739753
{
740-
map ??= new(StringComparer.Ordinal);
741-
for (int i = 0; i < count; i++)
754+
return false;
755+
}
756+
#endif
757+
var anyRequireApproval = false;
758+
759+
foreach (var tool in tools)
760+
{
761+
if (tool is AIToolGroup toolGroup)
762+
{
763+
var nestedTools = await toolGroup.GetToolsAsync(cancellationToken).ConfigureAwait(false);
764+
var nestedToolsRequireApproval = await AddToolListAsync(map, nestedTools, cancellationToken).ConfigureAwait(false);
765+
anyRequireApproval |= nestedToolsRequireApproval;
766+
}
767+
else
742768
{
743-
AITool tool = toolList[i];
744769
anyRequireApproval |= tool.GetService<ApprovalRequiredAIFunction>() is not null;
745770
map[tool.Name] = tool;
746771
}
747772
}
748-
}
749773

750-
return (map, anyRequireApproval);
774+
return anyRequireApproval;
775+
}
751776
}
752777

753778
/// <summary>

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,49 @@ public async Task ClonesChatOptionsAndResetContinuationTokenForBackgroundRespons
12321232
Assert.Null(actualChatOptions!.ContinuationToken);
12331233
}
12341234

1235+
[Fact]
1236+
public async Task ToolGroups_GetExpandedAutomatically()
1237+
{
1238+
var innerGroup = AIToolGroup.Create(
1239+
"InnerGroup",
1240+
"Inner group of tools",
1241+
new List<AITool>
1242+
{
1243+
AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"),
1244+
});
1245+
1246+
var outerGroup = AIToolGroup.Create(
1247+
"OuterGroup",
1248+
"Outer group of tools",
1249+
new List<AITool>
1250+
{
1251+
AIFunctionFactory.Create(() => "Result 1", "Func1"),
1252+
innerGroup,
1253+
AIFunctionFactory.Create((int i) => { }, "VoidReturn"),
1254+
});
1255+
1256+
ChatOptions options = new()
1257+
{
1258+
Tools = [outerGroup]
1259+
};
1260+
1261+
List<ChatMessage> plan =
1262+
[
1263+
new ChatMessage(ChatRole.User, "hello"),
1264+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]),
1265+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", result: "Result 1")]),
1266+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary<string, object?> { { "i", 42 } })]),
1267+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", result: "Result 2: 42")]),
1268+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary<string, object?> { { "i", 43 } })]),
1269+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", result: "Success: Function completed.")]),
1270+
new ChatMessage(ChatRole.Assistant, "world"),
1271+
];
1272+
1273+
await InvokeAndAssertAsync(options, plan);
1274+
1275+
await InvokeAndAssertStreamingAsync(options, plan);
1276+
}
1277+
12351278
private sealed class CustomSynchronizationContext : SynchronizationContext
12361279
{
12371280
public override void Post(SendOrPostCallback d, object? state)

0 commit comments

Comments
 (0)