-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Conversion of Anthropic Extension to M.E.AI interface
- Loading branch information
Showing
14 changed files
with
1,160 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
204 changes: 204 additions & 0 deletions
204
dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
Check failure on line 1 in dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs GitHub Actions / Dotnet Build (macos-latest, 3.11)
|
||
// AnthropicClient.cs | ||
|
||
using System.Buffers.Text; | ||
Check failure on line 4 in dotnet/src/Microsoft.AutoGen/Extensions/Anthropic/AnthropicChatCompletionClient.cs GitHub Actions / Dotnet Build (macos-latest, 3.11)
|
||
using System.Diagnostics.CodeAnalysis; | ||
using Microsoft.AutoGen.Extensions.Anthropic.DTO; | ||
using Microsoft.Extensions.AI; | ||
|
||
using MEAI=Microsoft.Extensions.AI; | ||
|
||
namespace Microsoft.AutoGen.Extensions.Anthropic; | ||
|
||
public static class AnthropicChatCompletionDefaults | ||
{ | ||
//public const string DefaultSystemMessage = "You are a helpful AI assistant"; | ||
public const decimal DefaultTemperature = 0.7m; | ||
public const int DefaultMaxTokens = 1024; | ||
} | ||
|
||
public sealed class AnthropicChatCompletionClient : IChatClient, IDisposable | ||
{ | ||
private AnthropicClient? _anthropicClient; | ||
private string _modelId; | ||
|
||
public AnthropicChatCompletionClient(HttpClient httpClient, string modelId, string baseUrl, string apiKey) | ||
: this(new AnthropicClient(httpClient, baseUrl, apiKey), modelId) | ||
{ | ||
} | ||
|
||
public AnthropicChatCompletionClient([NotNull] AnthropicClient client, string modelId) | ||
{ | ||
if (client == null) | ||
{ | ||
throw new ArgumentNullException(nameof(client)); | ||
} | ||
|
||
_anthropicClient = client; | ||
_modelId = modelId; | ||
|
||
if (!Uri.TryCreate(client.BaseUrl, UriKind.Absolute, out Uri? uri)) | ||
{ | ||
// technically we should never be able to get this far, in this case | ||
throw new ArgumentException($"Invalid base URL in provided client: {client.BaseUrl}", nameof(client)); | ||
} | ||
|
||
this.Metadata = new ChatClientMetadata("Anthropic", uri, modelId); | ||
} | ||
|
||
public ChatClientMetadata Metadata { get; private set; } | ||
|
||
private ContentBase Translate(AIContent content) | ||
{ | ||
if (content is MEAI.TextContent textContent) | ||
{ | ||
return (DTO.TextContent)textContent; | ||
} | ||
else if (content is MEAI.ImageContent imageContent) | ||
{ | ||
return (DTO.ImageContent)imageContent; | ||
} | ||
else if (content is MEAI.FunctionCallContent functionCallContent) | ||
{ | ||
return (DTO.ToolUseContent)functionCallContent; | ||
} | ||
else if (content is MEAI.FunctionResultContent functionResultContent) | ||
{ | ||
return (DTO.ToolResultContent)functionResultContent; | ||
} | ||
// TODO: Enable uage when it is ready | ||
else if (content is MEAI.UsageContent) | ||
{ | ||
throw new NotImplementedException("TODO!"); | ||
} | ||
else | ||
{ | ||
throw new ArgumentException($"Unsupported AIContent type: {content.GetType()}", nameof(content)); | ||
} | ||
} | ||
|
||
private List<ContentBase> Translate(IList<AIContent> content) | ||
{ | ||
return new List<ContentBase>(from rawContent in content select Translate(rawContent)); | ||
} | ||
|
||
private DTO.ChatMessage Translate(MEAI.ChatMessage message, List<SystemMessage>? systemMessagesSink = null) | ||
{ | ||
if (message.Role == ChatRole.System && systemMessagesSink != null) | ||
{ | ||
if (message.Contents.Count != 1 || message.Text == null) | ||
{ | ||
throw new Exception($"Invalid SystemMessage: May only contain a single Text AIContent. Actual: { | ||
String.Join(',', from contentObject in message.Contents select contentObject.GetType()) | ||
}"); | ||
} | ||
|
||
systemMessagesSink.Add(SystemMessage.CreateSystemMessage(message.Text)); | ||
} | ||
|
||
return new DTO.ChatMessage(message.Role.ToString().ToLowerInvariant(), Translate(message.Contents)); | ||
} | ||
|
||
private ChatCompletionRequest Translate(IList<MEAI.ChatMessage> chatMessages, ChatOptions? options, bool requestStream) | ||
{ | ||
ToolChoice? toolChoice = null; | ||
ChatToolMode? mode = options?.ToolMode; | ||
|
||
if (mode is AutoChatToolMode) | ||
{ | ||
toolChoice = ToolChoice.Auto; | ||
} | ||
else if (mode is RequiredChatToolMode requiredToolMode) | ||
{ | ||
if (requiredToolMode.RequiredFunctionName == null) | ||
{ | ||
toolChoice = ToolChoice.Any; | ||
} | ||
else | ||
{ | ||
toolChoice = ToolChoice.ToolUse(requiredToolMode.RequiredFunctionName!); | ||
} | ||
} | ||
|
||
List<SystemMessage> systemMessages = new List<SystemMessage>(); | ||
List<DTO.ChatMessage> translatedMessages = new(); | ||
|
||
foreach (MEAI.ChatMessage message in chatMessages) | ||
{ | ||
if (message.Role == ChatRole.System) | ||
{ | ||
Translate(message, systemMessages); | ||
|
||
// TODO: Should the system messages be included in the translatedMessages list? | ||
} | ||
else | ||
{ | ||
translatedMessages.Add(Translate(message)); | ||
} | ||
} | ||
|
||
return new ChatCompletionRequest | ||
{ | ||
Model = _modelId, | ||
|
||
// TODO: We should consider coming up with a reasonable default for MaxTokens, since the MAAi APIs do not require | ||
// it, while our wrapper for the Anthropic API does. | ||
MaxTokens = options?.MaxOutputTokens ?? throw new ArgumentException("Must specify number of tokens in request for Anthropic", nameof(options)), | ||
StopSequences = options?.StopSequences?.ToArray(), | ||
Stream = requestStream, | ||
Temperature = (decimal?)options?.Temperature, // TODO: why `decimal`?! | ||
ToolChoice = toolChoice, | ||
Tools = (from abstractTool in options?.Tools | ||
where abstractTool is AIFunction | ||
select (Tool)(AIFunction)abstractTool).ToList(), | ||
TopK = options?.TopK, | ||
TopP = (decimal?)options?.TopP, | ||
SystemMessage = systemMessages.ToArray(), | ||
Messages = translatedMessages, | ||
|
||
// TODO: put these somewhere? .Metadata? | ||
//ModelId = _modelId, | ||
//Options = options | ||
}; | ||
} | ||
|
||
private ChatCompletion Translate(ChatCompletionResponse response) | ||
{ | ||
response.Content | ||
|
||
ChatCompletion result = new ChatCompletion() | ||
{ | ||
// WIP | ||
} | ||
} | ||
|
||
public async Task<ChatCompletion> CompleteAsync(IList<Microsoft.Extensions.AI.ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) | ||
{ | ||
|
||
ChatCompletionRequest request = Translate(chatMessages, options, requestStream: false); | ||
ChatCompletionResponse response = await this.EnsureClient().CreateChatCompletionsAsync(request, cancellationToken); | ||
|
||
|
||
return Translate(response); | ||
} | ||
|
||
private AnthropicClient EnsureClient() | ||
{ | ||
return this._anthropicClient ?? throw new ObjectDisposedException(nameof(AnthropicChatCompletionClient)); | ||
} | ||
|
||
public IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<Microsoft.Extensions.AI.ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
|
||
public void Dispose() | ||
{ | ||
Interlocked.Exchange(ref this._anthropicClient, null)?.Dispose(); | ||
} | ||
|
||
public TService? GetService<TService>(object? key = null) where TService : class | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
} |
Oops, something went wrong.