From 86ee78aa6ac3df62c6ef641a66ffb55073f65cf7 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 8 Oct 2025 23:32:51 -0400 Subject: [PATCH] Add ChatToolMode.RequireSpecific(AITool) In the olden days (i.e. a few months ago), you could require the model/service to request invocation of any tool or of a specific function by name. Now, you can request that it invoke other tools as well. This adds another RequireSpecific overload that takes an AITool instead of a string function name, so that you can do things like RequireSpecific(webSearchTool). --- .../ChatCompletion/ChatToolMode.cs | 18 ++ .../ChatCompletion/RequiredChatToolMode.cs | 64 +++++- .../Microsoft.Extensions.AI.Abstractions.json | 12 + .../OpenAIAssistantsChatClient.cs | 8 + .../OpenAIResponsesChatClient.cs | 6 +- .../ChatCompletion/ChatToolModeTests.cs | 211 +++++++++++++++++- 6 files changed, 306 insertions(+), 13 deletions(-) diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs index 73134a5d894..78357774fcd 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/ChatToolMode.cs @@ -59,5 +59,23 @@ private protected ChatToolMode() /// /// The name of the required function. /// An instance of for the specified function name. + /// + /// Specifying a in a stored + /// into does not automatically include that tool in . + /// The tool must still be provided separately from the . + /// public static RequiredChatToolMode RequireSpecific(string functionName) => new(functionName); + + /// + /// Instantiates a indicating that tool usage is required, + /// and that the specified tool must be selected. + /// + /// The required tool. + /// An instance of for the specified tool. + /// + /// Specifying a in a stored + /// into does not automatically include that tool in . + /// The tool must still be provided separately from the . + /// + public static RequiredChatToolMode RequireSpecific(AITool tool) => new(tool); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs index 899ba04251e..e6cd934a287 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/ChatCompletion/RequiredChatToolMode.cs @@ -3,34 +3,59 @@ using System; using System.Diagnostics; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Extensions.AI; /// -/// Represents a mode where a chat tool must be called. This class can optionally nominate a specific function -/// or indicate that any of the functions can be selected. +/// Represents a mode where a chat tool must be called. This class can optionally nominate a specific tool +/// or indicate that any of the tools can be selected. /// [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class RequiredChatToolMode : ChatToolMode { /// - /// Gets the name of a specific tool that must be called. + /// Gets the name of a specific function tool that must be called. /// /// - /// If the value is , any available tool can be selected (but at least one must be). + /// If both and are , + /// any available tool can be selected (but at least one must be). /// public string? RequiredFunctionName { get; } + /// Gets the specific tool that must be called. + /// + /// + /// If both and are , + /// any available tool can be selected (but at least one must be). + /// + /// + /// Note that will not serialize to JSON as part of serializing + /// the instance, just as doesn't serialize. As such, attempting to + /// roundtrip a through JSON serialization may lead to the deserialized instance having + /// set to . + /// + /// + [JsonIgnore] + public AITool? RequiredTool { get; } + /// /// Initializes a new instance of the class that requires a specific tool to be called. /// - /// The name of the tool that must be called. + /// The name of the function that must be called. /// is empty or composed entirely of whitespace. /// + /// /// can be . However, it's preferable to use /// when any function can be selected. + /// + /// + /// The specified tool must also be included in the list of tools provided in the request, + /// such as via . + /// /// + [JsonConstructor] public RequiredChatToolMode(string? requiredFunctionName) { if (requiredFunctionName is not null) @@ -41,17 +66,42 @@ public RequiredChatToolMode(string? requiredFunctionName) RequiredFunctionName = requiredFunctionName; } + /// + /// Initializes a new instance of the class that requires a specific tool to be called. + /// + /// The specific tool that must be called. + /// + /// can be . However, it's preferable to use + /// when any function can be selected. + /// + /// + /// Specifying a in a stored + /// into does not automatically include that tool in . + /// The tool must still be provided separately from the . + /// + public RequiredChatToolMode(AITool? requiredTool) + { + if (requiredTool is not null) + { + RequiredTool = requiredTool; + RequiredFunctionName = requiredTool is AIFunctionDeclaration af ? af.Name : null; + } + } + /// Gets a string representing this instance to display in the debugger. [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"Required: {RequiredFunctionName ?? "Any"}"; + private string DebuggerDisplay => $"Required: {RequiredFunctionName ?? RequiredTool?.Name ?? "Any"}"; /// public override bool Equals(object? obj) => obj is RequiredChatToolMode other && - RequiredFunctionName == other.RequiredFunctionName; + (RequiredFunctionName is not null || other.RequiredFunctionName is not null ? + RequiredFunctionName == other.RequiredFunctionName : + Equals(RequiredTool, other.RequiredTool)); /// public override int GetHashCode() => RequiredFunctionName?.GetHashCode(StringComparison.Ordinal) ?? + RequiredTool?.GetHashCode() ?? typeof(RequiredChatToolMode).GetHashCode(); } diff --git a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json index daa5063cf51..7afb8f98f83 100644 --- a/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json +++ b/src/Libraries/Microsoft.Extensions.AI.Abstractions/Microsoft.Extensions.AI.Abstractions.json @@ -1388,6 +1388,10 @@ { "Member": "static Microsoft.Extensions.AI.RequiredChatToolMode Microsoft.Extensions.AI.ChatToolMode.RequireSpecific(string functionName);", "Stage": "Stable" + }, + { + "Member": "static Microsoft.Extensions.AI.RequiredChatToolMode Microsoft.Extensions.AI.ChatToolMode.RequireSpecific(Microsoft.Extensions.AI.AITool tool);", + "Stage": "Stable" } ], "Properties": [ @@ -2099,6 +2103,10 @@ "Member": "Microsoft.Extensions.AI.RequiredChatToolMode.RequiredChatToolMode(string? requiredFunctionName);", "Stage": "Stable" }, + { + "Member": "Microsoft.Extensions.AI.RequiredChatToolMode.RequiredChatToolMode(Microsoft.Extensions.AI.AITool? requiredTool);", + "Stage": "Stable" + }, { "Member": "override bool Microsoft.Extensions.AI.RequiredChatToolMode.Equals(object? obj);", "Stage": "Stable" @@ -2112,6 +2120,10 @@ { "Member": "string? Microsoft.Extensions.AI.RequiredChatToolMode.RequiredFunctionName { get; }", "Stage": "Stable" + }, + { + "Member": "Microsoft.Extensions.AI.AITool? Microsoft.Extensions.AI.RequiredChatToolMode.RequiredTool { get; }", + "Stage": "Stable" } ] }, diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantsChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantsChatClient.cs index 1de5dd79d4a..5849405f445 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantsChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantsChatClient.cs @@ -397,6 +397,14 @@ internal static FunctionToolDefinition ToOpenAIAssistantsFunctionToolDefinition( runOptions.ToolConstraint = new ToolConstraint(ToolDefinition.CreateFunction(functionName)); break; + case RequiredChatToolMode required when required.RequiredTool is HostedCodeInterpreterTool: + runOptions.ToolConstraint = new ToolConstraint(ToolDefinition.CreateCodeInterpreter()); + break; + + case RequiredChatToolMode required when required.RequiredTool is HostedFileSearchTool: + runOptions.ToolConstraint = new ToolConstraint(ToolDefinition.CreateFileSearch()); + break; + case RequiredChatToolMode required: runOptions.ToolConstraint = ToolConstraint.Required; break; diff --git a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs index 5da26a435ff..cfb2afe9abc 100644 --- a/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIResponsesChatClient.cs @@ -531,8 +531,10 @@ private ResponseCreationOptions ToOpenAIResponseCreationOptions(ChatOptions? opt break; case RequiredChatToolMode required: - result.ToolChoice = required.RequiredFunctionName is not null ? - ResponseToolChoice.CreateFunctionChoice(required.RequiredFunctionName) : + result.ToolChoice = + required.RequiredFunctionName is not null ? ResponseToolChoice.CreateFunctionChoice(required.RequiredFunctionName) : + required.RequiredTool is HostedWebSearchTool || required.RequiredTool is ResponseToolAITool { Tool: WebSearchTool } ? ResponseToolChoice.CreateWebSearchChoice() : + required.RequiredTool is HostedFileSearchTool || required.RequiredTool is ResponseToolAITool { Tool: FileSearchTool } ? ResponseToolChoice.CreateFileSearchChoice() : ResponseToolChoice.CreateRequiredChoice(); break; } diff --git a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs index e0c00769277..fd5f551e561 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Abstractions.Tests/ChatCompletion/ChatToolModeTests.cs @@ -22,14 +22,14 @@ public void Equality_ComparersProduceExpectedResults() Assert.True(ChatToolMode.Auto == ChatToolMode.Auto); Assert.True(ChatToolMode.Auto.Equals(ChatToolMode.Auto)); Assert.False(ChatToolMode.Auto.Equals(ChatToolMode.RequireAny)); - Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode(null))); + Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode((string?)null))); Assert.False(ChatToolMode.Auto.Equals(new RequiredChatToolMode("func"))); Assert.Equal(ChatToolMode.Auto.GetHashCode(), ChatToolMode.Auto.GetHashCode()); Assert.True(ChatToolMode.None == ChatToolMode.None); Assert.True(ChatToolMode.None.Equals(ChatToolMode.None)); Assert.False(ChatToolMode.None.Equals(ChatToolMode.RequireAny)); - Assert.False(ChatToolMode.None.Equals(new RequiredChatToolMode(null))); + Assert.False(ChatToolMode.None.Equals(new RequiredChatToolMode((string?)null))); Assert.False(ChatToolMode.None.Equals(new RequiredChatToolMode("func"))); Assert.Equal(ChatToolMode.None.GetHashCode(), ChatToolMode.None.GetHashCode()); @@ -38,8 +38,8 @@ public void Equality_ComparersProduceExpectedResults() Assert.False(ChatToolMode.RequireAny.Equals(ChatToolMode.Auto)); Assert.False(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode("func"))); - Assert.True(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode(null))); - Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), new RequiredChatToolMode(null).GetHashCode()); + Assert.True(ChatToolMode.RequireAny.Equals(new RequiredChatToolMode((string?)null))); + Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), new RequiredChatToolMode((string?)null).GetHashCode()); Assert.Equal(ChatToolMode.RequireAny.GetHashCode(), ChatToolMode.RequireAny.GetHashCode()); Assert.True(new RequiredChatToolMode("func").Equals(new RequiredChatToolMode("func"))); @@ -91,4 +91,207 @@ public void Serialization_RequireSpecificRoundtrips() ChatToolMode? result = JsonSerializer.Deserialize(json, TestJsonSerializerContext.Default.ChatToolMode); Assert.Equal(ChatToolMode.RequireSpecific("myFunc"), result); } + + [Fact] + public void RequireSpecific_WithAIFunction_UsesCorrectFunctionName() + { + var function = AIFunctionFactory.Create(() => { }, "myFunction"); + + var result = ChatToolMode.RequireSpecific(function); + Assert.IsType(result); + + var requiredMode = Assert.IsType(result); + Assert.Same(function, requiredMode.RequiredTool); + Assert.Equal("myFunction", requiredMode.RequiredFunctionName); + } + + [Fact] + public void RequireSpecific_WithNonFunctionTool_SetsRequiredToolButNullFunctionName() + { + var tool = new TestNonFunctionTool("nonFunctionTool"); + + var result = ChatToolMode.RequireSpecific(tool); + Assert.IsType(result); + + var requiredMode = Assert.IsType(result); + Assert.Same(tool, requiredMode.RequiredTool); + Assert.Null(requiredMode.RequiredFunctionName); + } + + [Fact] + public void RequiredChatToolMode_Constructor_WithAITool_SetsProperties() + { + var tool = AIFunctionFactory.Create(() => { }, "testFunc"); + + RequiredChatToolMode mode = new(tool); + + Assert.Same(tool, mode.RequiredTool); + Assert.Equal("testFunc", mode.RequiredFunctionName); + } + + [Fact] + public void RequiredChatToolMode_Constructor_WithNullAITool_SetsPropertiesCorrectly() + { + RequiredChatToolMode mode = new((AITool?)null); + + Assert.Null(mode.RequiredTool); + Assert.Null(mode.RequiredFunctionName); + } + + [Fact] + public void RequiredChatToolMode_Constructor_WithNonFunctionTool_SetsToolButNullFunctionName() + { + TestNonFunctionTool tool = new("nonFunc"); + + RequiredChatToolMode mode = new(tool); + Assert.Same(tool, mode.RequiredTool); + + Assert.Null(mode.RequiredFunctionName); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithSameAITool_ReturnsTrue() + { + var tool = AIFunctionFactory.Create(() => { }, "testFunc"); + RequiredChatToolMode mode1 = new(tool); + RequiredChatToolMode mode2 = new(tool); + + Assert.True(mode1.Equals(mode2)); + Assert.True(mode2.Equals(mode1)); + Assert.Equal(mode1.GetHashCode(), mode2.GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithDifferentAITools_ReturnsFalse() + { + RequiredChatToolMode mode1 = new(AIFunctionFactory.Create(() => 42, "func1")); + RequiredChatToolMode mode2 = new(AIFunctionFactory.Create(() => 43, "func2")); + + Assert.False(mode1.Equals(mode2)); + Assert.False(mode2.Equals(mode1)); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithMatchingFunctionNameAndTool_ReturnsTrue() + { + RequiredChatToolMode modeWithTool = new(AIFunctionFactory.Create(() => { }, "func")); + RequiredChatToolMode modeWithFunctionName = new("func"); + + Assert.True(modeWithTool.Equals(modeWithFunctionName)); + Assert.True(modeWithFunctionName.Equals(modeWithTool)); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithNonMatchingFunctionNameAndTool_ReturnsFalse() + { + RequiredChatToolMode modeWithTool = new(AIFunctionFactory.Create(() => { }, "func1")); + RequiredChatToolMode modeWithFunctionName = new("func2"); + + Assert.False(modeWithTool.Equals(modeWithFunctionName)); + Assert.False(modeWithFunctionName.Equals(modeWithTool)); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithBothNull_ReturnsTrue() + { + RequiredChatToolMode mode1 = new((AITool?)null); + RequiredChatToolMode mode2 = new((string?)null); + + Assert.True(mode1.Equals(mode2)); + Assert.True(mode2.Equals(mode1)); + Assert.Equal(mode1.GetHashCode(), mode2.GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithNullAndSpecific_ReturnsFalse() + { + RequiredChatToolMode modeWithTool = new(AIFunctionFactory.Create(() => { }, "func")); + RequiredChatToolMode modeWithNull = new((AITool?)null); + + Assert.False(modeWithTool.Equals(modeWithNull)); + Assert.False(modeWithNull.Equals(modeWithTool)); + } + + [Fact] + public void RequiredChatToolMode_GetHashCode_ConsistentForSameInstance() + { + RequiredChatToolMode mode = new(AIFunctionFactory.Create(() => { }, "func")); + Assert.Equal(mode.GetHashCode(), mode.GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_GetHashCode_WithNullTool_ReturnsTypeHashCode() + { + Assert.Equal(typeof(RequiredChatToolMode).GetHashCode(), new RequiredChatToolMode((AITool?)null).GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_GetHashCode_WithFunctionName_ReturnsStringHashCode() + { + Assert.Equal("testFunc".GetHashCode(), new RequiredChatToolMode("testFunc").GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_GetHashCode_WithTool_ReturnsToolHashCode() + { + RequiredChatToolMode mode = new(AIFunctionFactory.Create(() => { }, "func")); + Assert.Equal("func".GetHashCode(), mode.GetHashCode()); + } + + [Fact] + public void RequiredChatToolMode_RequiredTool_IsNotSerialized() + { + RequiredChatToolMode mode = new(AIFunctionFactory.Create(() => { }, "func")); + Assert.Equal( + """{"$type":"required","requiredFunctionName":"func"}""", + JsonSerializer.Serialize(mode, TestJsonSerializerContext.Default.ChatToolMode)); + } + + [Fact] + public void RequiredChatToolMode_DeserializationDoesNotRestoreRequiredTool() + { + RequiredChatToolMode originalMode = new(AIFunctionFactory.Create(() => { }, "func")); + + var deserializedMode = JsonSerializer.Deserialize( + JsonSerializer.Serialize(originalMode, TestJsonSerializerContext.Default.ChatToolMode), + TestJsonSerializerContext.Default.ChatToolMode) as RequiredChatToolMode; + + Assert.NotNull(deserializedMode); + Assert.Equal("func", deserializedMode.RequiredFunctionName); + Assert.Null(deserializedMode.RequiredTool); + } + + [Fact] + public void RequiredChatToolMode_Equals_HandlesMixedToolAndFunctionNameScenarios() + { + RequiredChatToolMode modeWithTool1 = new(AIFunctionFactory.Create(() => 42, "sameName")); + RequiredChatToolMode modeWithTool2 = new(AIFunctionFactory.Create(() => 43, "sameName")); + RequiredChatToolMode modeWithFunctionName = new("sameName"); + + Assert.True(modeWithTool1.Equals(modeWithTool2)); + + Assert.True(modeWithTool1.Equals(modeWithFunctionName)); + Assert.True(modeWithTool2.Equals(modeWithFunctionName)); + Assert.True(modeWithFunctionName.Equals(modeWithTool1)); + Assert.True(modeWithFunctionName.Equals(modeWithTool2)); + } + + [Fact] + public void RequiredChatToolMode_Equals_WithNonFunctionTools() + { + TestNonFunctionTool tool1 = new("tool1"); + RequiredChatToolMode mode1 = new(tool1); + RequiredChatToolMode mode2 = new(new TestNonFunctionTool("tool2")); + RequiredChatToolMode mode3 = new(tool1); + + Assert.True(mode1.Equals(mode3)); + Assert.False(mode1.Equals(mode2)); + Assert.Equal(mode1.GetHashCode(), mode3.GetHashCode()); + } + + private sealed class TestNonFunctionTool(string name) : AITool + { + public override string Name => name; + public override string Description => "Non-function tool"; + } }