diff --git a/dotnet/README.md b/dotnet/README.md index d036b912dac9..6483a7d035aa 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -77,7 +77,7 @@ requirements and setup instructions. 3. [Running AI prompts from file](./notebooks/02-running-prompts-from-file.ipynb) 4. [Creating Semantic Functions at runtime (i.e. inline functions)](./notebooks/03-semantic-function-inline.ipynb) 5. [Using Kernel Arguments to Build a Chat Experience](./notebooks/04-kernel-arguments-chat.ipynb) -6. [Creating and Executing Plans](./notebooks/05-using-the-planner.ipynb) +6. [Introduction to the Planning/Function Calling](./notebooks/05-using-function-calling.ipynb) 7. [Building Memory with Embeddings](./notebooks/06-memory-and-embeddings.ipynb) 8. [Creating images with DALL-E 3](./notebooks/07-DALL-E-3.ipynb) 9. [Chatting with ChatGPT and Images](./notebooks/08-chatGPT-with-DALL-E-3.ipynb) diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 617f05940d5d..9951f71010ab 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -373,11 +373,22 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.InMemory.UnitTes EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GettingStartedWithTextSearch", "samples\GettingStartedWithTextSearch\GettingStartedWithTextSearch.csproj", "{16AFA226-E417-490D-9311-9F2099A1EEC8}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "VectorStoreRAG", "samples\Demos\VectorStoreRAG\VectorStoreRAG.csproj", "{28DFAF27-8FF3-4373-AAA4-2A6969C86246}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "VectorStoreRAG", "samples\Demos\VectorStoreRAG\VectorStoreRAG.csproj", "{28DFAF27-8FF3-4373-AAA4-2A6969C86246}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Process.Runtime.Dapr", "src\Experimental\Process.Runtime.Dapr\Process.Runtime.Dapr.csproj", "{9D5B4B53-0E97-42D9-B37E-CD263B6A1892}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ProcessWithDapr", "samples\Demos\ProcessWithDapr\ProcessWithDapr.csproj", "{95163AA2-1ED5-412A-990B-C40B81934BFD}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ProcessWithDapr", "samples\Demos\ProcessWithDapr\ProcessWithDapr.csproj", "{95163AA2-1ED5-412A-990B-C40B81934BFD}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.MongoDB.UnitTests", "src\Connectors\Connectors.MongoDB.UnitTests\Connectors.MongoDB.UnitTests.csproj", "{6F591D05-5F7F-4211-9042-42D8BCE60415}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Memory", "Memory", "{077928EA-2C61-4667-82FC-6A5120B7AC45}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "MongoDB", "MongoDB", "{AF7F68FD-ADB0-4941-90AE-88EAAB53BEEB}" + ProjectSection(SolutionItems) = preProject + src\InternalUtilities\connectors\Memory\MongoDB\MongoDBConstants.cs = src\InternalUtilities\connectors\Memory\MongoDB\MongoDBConstants.cs + src\InternalUtilities\connectors\Memory\MongoDB\MongoDBGenericDataModelMapper.cs = src\InternalUtilities\connectors\Memory\MongoDB\MongoDBGenericDataModelMapper.cs + src\InternalUtilities\connectors\Memory\MongoDB\MongoDBVectorStoreRecordMapper.cs = src\InternalUtilities\connectors\Memory\MongoDB\MongoDBVectorStoreRecordMapper.cs + EndProjectSection EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -992,6 +1003,12 @@ Global {95163AA2-1ED5-412A-990B-C40B81934BFD}.Publish|Any CPU.Build.0 = Debug|Any CPU {95163AA2-1ED5-412A-990B-C40B81934BFD}.Release|Any CPU.ActiveCfg = Release|Any CPU {95163AA2-1ED5-412A-990B-C40B81934BFD}.Release|Any CPU.Build.0 = Release|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Publish|Any CPU.Build.0 = Debug|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6F591D05-5F7F-4211-9042-42D8BCE60415}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1129,6 +1146,9 @@ Global {28DFAF27-8FF3-4373-AAA4-2A6969C86246} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {9D5B4B53-0E97-42D9-B37E-CD263B6A1892} = {0D8C6358-5DAA-4EA6-A924-C268A9A21BC9} {95163AA2-1ED5-412A-990B-C40B81934BFD} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} + {6F591D05-5F7F-4211-9042-42D8BCE60415} = {5A7028A7-4DDF-4E4F-84A9-37CE8F8D7E89} + {077928EA-2C61-4667-82FC-6A5120B7AC45} = {314A2705-0F70-44B6-8988-C6DF77BDFD42} + {AF7F68FD-ADB0-4941-90AE-88EAAB53BEEB} = {077928EA-2C61-4667-82FC-6A5120B7AC45} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {FBDC56A3-86AD-4323-AA0F-201E59123B83} diff --git a/dotnet/notebooks/05-using-function-calling.ipynb b/dotnet/notebooks/05-using-function-calling.ipynb new file mode 100644 index 000000000000..62acd6c67ca1 --- /dev/null +++ b/dotnet/notebooks/05-using-function-calling.ipynb @@ -0,0 +1,219 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Introduction to the Function Calling\n", + "\n", + "The most powerful feature of chat completion is the ability to call functions from the model. This allows you to create a chat bot that can interact with your existing code, making it possible to automate business processes, create code snippets, and more.\n", + "\n", + "With Semantic Kernel, we simplify the process of using function calling by automatically describing your functions and their parameters to the model and then handling the back-and-forth communication between the model and your code.\n", + "\n", + "Read more about it [here](https://learn.microsoft.com/en-us/semantic-kernel/concepts/ai-services/chat-completion/function-calling)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "dotnet_interactive": { + "language": "csharp" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "outputs": [], + "source": [ + "#r \"nuget: Microsoft.SemanticKernel, 1.23.0\"\n", + "\n", + "#!import config/Settings.cs\n", + "#!import config/Utils.cs\n", + "\n", + "using Microsoft.SemanticKernel;\n", + "using Microsoft.SemanticKernel.Connectors.OpenAI;\n", + "using Kernel = Microsoft.SemanticKernel.Kernel;\n", + "\n", + "var builder = Kernel.CreateBuilder();\n", + "\n", + "// Configure AI backend used by the kernel\n", + "var (useAzureOpenAI, model, azureEndpoint, apiKey, orgId) = Settings.LoadFromFile();\n", + "\n", + "if (useAzureOpenAI)\n", + " builder.AddAzureOpenAIChatCompletion(model, azureEndpoint, apiKey);\n", + "else\n", + " builder.AddOpenAIChatCompletion(model, apiKey, orgId);\n", + "\n", + "var kernel = builder.Build();" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setting Up Execution Settings" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `FunctionChoiceBehavior.Auto()` will enable automatic function calling. There are also other options like `Required` or `None` which allow to control function calling behavior. More information about it can be found [here](https://learn.microsoft.com/en-gb/semantic-kernel/concepts/ai-services/chat-completion/function-calling/function-choice-behaviors?pivots=programming-language-csharp)." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "dotnet_interactive": { + "language": "csharp" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "outputs": [], + "source": [ + "#pragma warning disable SKEXP0001\n", + "\n", + "OpenAIPromptExecutionSettings openAIPromptExecutionSettings = new() \n", + "{\n", + " FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()\n", + "};" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Providing plugins to the Kernel\n", + "Function calling needs an information about available plugins/functions. Here we'll import the `SummarizePlugin` and `WriterPlugin` we have defined on disk." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "dotnet_interactive": { + "language": "csharp" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "outputs": [], + "source": [ + "var pluginsDirectory = Path.Combine(System.IO.Directory.GetCurrentDirectory(), \"..\", \"..\", \"prompt_template_samples\");\n", + "\n", + "kernel.ImportPluginFromPromptDirectory(Path.Combine(pluginsDirectory, \"SummarizePlugin\"));\n", + "kernel.ImportPluginFromPromptDirectory(Path.Combine(pluginsDirectory, \"WriterPlugin\"));" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define your ASK. What do you want the Kernel to do?" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "dotnet_interactive": { + "language": "csharp" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "outputs": [], + "source": [ + "var ask = \"Tomorrow is Valentine's day. I need to come up with a few date ideas. My significant other likes poems so write them in the form of a poem.\";" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we imported available plugins to Kernel and defined the ask, we can now invoke a prompt with all the provided information. \n", + "\n", + "We can run function calling with Kernel, if we are interested in result only." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "dotnet_interactive": { + "language": "csharp" + }, + "polyglot_notebook": { + "kernelName": "csharp" + } + }, + "outputs": [], + "source": [ + "var result = await kernel.InvokePromptAsync(ask, new(openAIPromptExecutionSettings));\n", + "\n", + "Console.WriteLine(result);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But we can also run it with `IChatCompletionService` to have an access to `ChatHistory` object, which allows us to see which functions were called as part of a function calling process. Note that passing a Kernel as a parameter to `GetChatMessageContentAsync` method is required, since Kernel holds an information about available plugins." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "using Microsoft.SemanticKernel.ChatCompletion;\n", + "\n", + "var chatCompletionService = kernel.GetRequiredService();\n", + "\n", + "var chatHistory = new ChatHistory();\n", + "\n", + "chatHistory.AddUserMessage(ask);\n", + "\n", + "var chatCompletionResult = await chatCompletionService.GetChatMessageContentAsync(chatHistory, openAIPromptExecutionSettings, kernel);\n", + "\n", + "Console.WriteLine($\"Result: {chatCompletionResult}\\n\");\n", + "Console.WriteLine($\"Chat history: {JsonSerializer.Serialize(chatHistory)}\\n\");" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".NET (C#)", + "language": "C#", + "name": ".net-csharp" + }, + "language_info": { + "name": "polyglot-notebook" + }, + "polyglot_notebook": { + "kernelInfo": { + "defaultKernelName": "csharp", + "items": [ + { + "aliases": [], + "name": "csharp" + } + ] + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dotnet/notebooks/05-using-the-planner.ipynb b/dotnet/notebooks/05-using-the-planner.ipynb deleted file mode 100644 index 779f9214645c..000000000000 --- a/dotnet/notebooks/05-using-the-planner.ipynb +++ /dev/null @@ -1,302 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Introduction to the Planner\n", - "\n", - "The Planner is one of the fundamental concepts of the Semantic Kernel. It makes use of the collection of plugins that have been registered to the kernel and using AI, will formulate a plan to execute a given ask.\n", - "\n", - "Read more about it [here](https://aka.ms/sk/concepts/planner)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "#r \"nuget: Microsoft.SemanticKernel, 1.23.0\"\n", - "#r \"nuget: Microsoft.SemanticKernel.Planners.Handlebars, 1.23.0-preview\"\n", - "\n", - "#!import config/Settings.cs\n", - "#!import config/Utils.cs\n", - "\n", - "using Microsoft.SemanticKernel;\n", - "using Microsoft.SemanticKernel.Connectors.OpenAI;\n", - "using Kernel = Microsoft.SemanticKernel.Kernel;\n", - "\n", - "var builder = Kernel.CreateBuilder();\n", - "\n", - "// Configure AI backend used by the kernel\n", - "var (useAzureOpenAI, model, azureEndpoint, apiKey, orgId) = Settings.LoadFromFile();\n", - "\n", - "if (useAzureOpenAI)\n", - " builder.AddAzureOpenAIChatCompletion(model, azureEndpoint, apiKey);\n", - "else\n", - " builder.AddOpenAIChatCompletion(model, apiKey, orgId);\n", - "\n", - "var kernel = builder.Build();" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setting Up Handlebars Planner\n", - "Handlebars Planner is located in the `Microsoft.SemanticKernel.Planning.Handlebars` package." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "using Microsoft.SemanticKernel.Planning.Handlebars;\n", - "\n", - "#pragma warning disable SKEXP0060\n", - "\n", - "var planner = new HandlebarsPlanner();" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Providing plugins to the planner\n", - "The planner needs to know what plugins are available to it. Here we'll import the `SummarizePlugin` and `WriterPlugin` we have defined on disk." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "var pluginsDirectory = Path.Combine(System.IO.Directory.GetCurrentDirectory(), \"..\", \"..\", \"prompt_template_samples\");\n", - "\n", - "kernel.ImportPluginFromPromptDirectory(Path.Combine(pluginsDirectory, \"SummarizePlugin\"));\n", - "kernel.ImportPluginFromPromptDirectory(Path.Combine(pluginsDirectory, \"WriterPlugin\"));" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Define your ASK. What do you want the Kernel to do?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "#pragma warning disable SKEXP0060\n", - "\n", - "var ask = \"Tomorrow is Valentine's day. I need to come up with a few date ideas. My significant other likes poems so write them in the form of a poem.\";\n", - "var originalPlan = await planner.CreatePlanAsync(kernel, ask);\n", - "\n", - "Console.WriteLine(\"Original plan:\\n\");\n", - "Console.WriteLine(originalPlan);" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As you can see in the above plan, the Planner has taken the user's ask and converted it into a Plan object detailing how the AI would go about solving this task.\n", - "\n", - "It makes use of the plugins that the Kernel has available to it and determines which functions to call in order to fulfill the user's ask." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's also define an inline plugin and have it be available to the Planner.\n", - "Be sure to give it a function name and plugin name." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "string skPrompt = \"\"\"\n", - "{{$input}}\n", - "\n", - "Rewrite the above in the style of Shakespeare.\n", - "\"\"\";\n", - "\n", - "var executionSettings = new OpenAIPromptExecutionSettings \n", - "{\n", - " MaxTokens = 2000,\n", - " Temperature = 0.7,\n", - " TopP = 0.5\n", - "};\n", - "\n", - "var shakespeareFunction = kernel.CreateFunctionFromPrompt(skPrompt, executionSettings, \"Shakespeare\");" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's update our ask using this new plugin." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "#pragma warning disable SKEXP0060\n", - "\n", - "var ask = @\"Tomorrow is Valentine's day. I need to come up with a few date ideas.\n", - "She likes Shakespeare so write using his style. Write them in the form of a poem.\";\n", - "\n", - "var newPlan = await planner.CreatePlanAsync(kernel, ask);\n", - "\n", - "Console.WriteLine(\"Updated plan:\\n\");\n", - "Console.WriteLine(newPlan);" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Executing the plans\n", - "\n", - "Now that we have different plans, let's try to execute them! The Kernel can execute the plan using RunAsync." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "#pragma warning disable SKEXP0060\n", - "\n", - "var originalPlanResult = await originalPlan.InvokeAsync(kernel, new KernelArguments());\n", - "\n", - "Console.WriteLine(\"Original Plan results:\\n\");\n", - "Console.WriteLine(Utils.WordWrap(originalPlanResult.ToString(), 100));" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now lets execute and print the new plan:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "dotnet_interactive": { - "language": "csharp" - }, - "polyglot_notebook": { - "kernelName": "csharp" - } - }, - "outputs": [], - "source": [ - "#pragma warning disable SKEXP0060\n", - "\n", - "var newPlanResult = await newPlan.InvokeAsync(kernel, new KernelArguments());\n", - "\n", - "Console.WriteLine(\"New Plan results:\\n\");\n", - "Console.WriteLine(newPlanResult);" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".NET (C#)", - "language": "C#", - "name": ".net-csharp" - }, - "language_info": { - "name": "polyglot-notebook" - }, - "polyglot_notebook": { - "kernelInfo": { - "defaultKernelName": "csharp", - "items": [ - { - "aliases": [], - "name": "csharp" - } - ] - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/dotnet/notebooks/README.md b/dotnet/notebooks/README.md index af5503ff2661..83c8d880ebfd 100644 --- a/dotnet/notebooks/README.md +++ b/dotnet/notebooks/README.md @@ -58,7 +58,7 @@ For a quick dive, look at the [getting started notebook](00-getting-started.ipyn 2. [Running AI prompts from file](02-running-prompts-from-file.ipynb) 3. [Creating Semantic Functions at runtime (i.e. inline functions)](03-semantic-function-inline.ipynb) 4. [Using Kernel Arguments to Build a Chat Experience](04-kernel-arguments-chat.ipynb) -5. [Creating and Executing Plans](05-using-the-planner.ipynb) +5. [Introduction to the Planning/Function Calling](05-using-function-calling.ipynb) 6. [Building Memory with Embeddings](06-memory-and-embeddings.ipynb) 7. [Creating images with DALL-E 3](07-DALL-E-3.ipynb) 8. [Chatting with ChatGPT and Images](08-chatGPT-with-DALL-E-3.ipynb) diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs index c697b1fde0dc..6e41eb7f3cb9 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; @@ -13,11 +14,11 @@ namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; /// internal static class AzureCosmosDBMongoDBVectorStoreCollectionSearchMapping { - /// Returns index kind specified on vector property or default . - public static string GetVectorPropertyIndexKind(string? indexKind) => !string.IsNullOrWhiteSpace(indexKind) ? indexKind! : AzureCosmosDBMongoDBConstants.DefaultIndexKind; + /// Returns index kind specified on vector property or default . + public static string GetVectorPropertyIndexKind(string? indexKind) => !string.IsNullOrWhiteSpace(indexKind) ? indexKind! : MongoDBConstants.DefaultIndexKind; - /// Returns distance function specified on vector property or default . - public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : AzureCosmosDBMongoDBConstants.DefaultDistanceFunction; + /// Returns distance function specified on vector property or default . + public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; /// /// Build Azure CosmosDB MongoDB filter from the provided . diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs index 2f683c73ef92..94f423743d65 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordCollection.cs @@ -8,6 +8,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; using MongoDB.Driver; @@ -72,7 +73,7 @@ public AzureCosmosDBMongoDBVectorStoreRecordCollection( // Verify. Verify.NotNull(mongoDatabase); Verify.NotNullOrWhiteSpace(collectionName); - VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, AzureCosmosDBMongoDBConstants.SupportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, MongoDBConstants.SupportedKeyTypes); VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); // Assign. @@ -85,7 +86,7 @@ public AzureCosmosDBMongoDBVectorStoreRecordCollection( this._storagePropertyNames = GetStoragePropertyNames(this._propertyReader.Properties, typeof(TRecord)); // Use Mongo reserved key property name as storage key property name - this._storagePropertyNames[this._propertyReader.KeyPropertyName] = AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName; + this._storagePropertyNames[this._propertyReader.KeyPropertyName] = MongoDBConstants.MongoReservedKeyPropertyName; this._vectorStoragePropertyNames = this._propertyReader.VectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); @@ -211,7 +212,7 @@ public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = n OperationName, () => this._mapper.MapFromDataToStorageModel(record)); - var key = storageModel[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName].AsString; + var key = storageModel[MongoDBConstants.MongoReservedKeyPropertyName].AsString; return this.RunOperationAsync(OperationName, async () => { @@ -402,10 +403,10 @@ private async IAsyncEnumerable> EnumerateAndMapSearc } private FilterDefinition GetFilterById(string id) - => Builders.Filter.Eq(document => document[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName], id); + => Builders.Filter.Eq(document => document[MongoDBConstants.MongoReservedKeyPropertyName], id); private FilterDefinition GetFilterByIds(IEnumerable ids) - => Builders.Filter.In(document => document[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName].AsString, ids); + => Builders.Filter.In(document => document[MongoDBConstants.MongoReservedKeyPropertyName].AsString, ids); private async Task InternalCollectionExistsAsync(CancellationToken cancellationToken) { @@ -521,10 +522,10 @@ private IVectorStoreRecordMapper InitializeMapper() if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) { - return (new AzureCosmosDBMongoDBGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper)!; + return (new MongoDBGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper)!; } - return new AzureCosmosDBMongoDBVectorStoreRecordMapper(this._propertyReader); + return new MongoDBVectorStoreRecordMapper(this._propertyReader); } #endregion diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj index 9a2d68d24132..7f30b060aa2c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/Connectors.Memory.AzureCosmosDBMongoDB.csproj @@ -13,6 +13,10 @@ + + + + Semantic Kernel - Azure CosmosDB MongoDB vCore Connector diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj b/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj index 12b037d1071a..b091931d6e9e 100644 --- a/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/Connectors.Memory.MongoDB.csproj @@ -5,13 +5,17 @@ Microsoft.SemanticKernel.Connectors.MongoDB $(AssemblyName) net8.0;netstandard2.0 - alpha + preview + + + + Semantic Kernel - MongoDB Connector @@ -26,4 +30,8 @@ + + + + diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs new file mode 100644 index 000000000000..3226fd9b4cc2 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/IMongoDBVectorStoreRecordCollectionFactory.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using MongoDB.Driver; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Interface for constructing MongoDB instances when using to retrieve these. +/// +public interface IMongoDBVectorStoreRecordCollectionFactory +{ + /// + /// Constructs a new instance of the . + /// + /// The data type of the record key. + /// The data model to use for adding, updating and retrieving data from storage. + /// that can be used to manage the collections in MongoDB. + /// The name of the collection to connect to. + /// An optional record definition that defines the schema of the record type. If not present, attributes on will be used. + /// The new instance of . + IVectorStoreRecordCollection CreateVectorStoreRecordCollection(IMongoDatabase mongoDatabase, string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition) + where TKey : notnull; +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs new file mode 100644 index 000000000000..b8e89aab82da --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBServiceCollectionExtensions.cs @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using Microsoft.SemanticKernel.Http; +using MongoDB.Driver; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods to register MongoDB instances on an . +/// +public static class MongoDBServiceCollectionExtensions +{ + /// + /// Register a MongoDB with the specified service ID + /// and where the MongoDB is retrieved from the dependency injection container. + /// + /// The to register the on. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddMongoDBVectorStore( + this IServiceCollection services, + MongoDBVectorStoreOptions? options = default, + string? serviceId = default) + { + // If we are not constructing MongoDatabase, add the IVectorStore as transient, since we + // cannot make assumptions about how MongoDatabase is being managed. + services.AddKeyedTransient( + serviceId, + (sp, obj) => + { + var database = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService(); + + return new MongoDBVectorStore(database, options); + }); + + return services; + } + + /// + /// Register a MongoDB with the specified service ID + /// and where the MongoDB is constructed using the provided and . + /// + /// The to register the on. + /// Connection string required to connect to MongoDB. + /// Database name for MongoDB. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddMongoDBVectorStore( + this IServiceCollection services, + string connectionString, + string databaseName, + MongoDBVectorStoreOptions? options = default, + string? serviceId = default) + { + // If we are constructing IMongoDatabase, add the IVectorStore as singleton, since we are managing the lifetime of it, + // and the recommendation from Mongo is to register it with a singleton lifetime. + services.AddKeyedSingleton( + serviceId, + (sp, obj) => + { + var settings = MongoClientSettings.FromConnectionString(connectionString); + settings.ApplicationName = HttpHeaderConstant.Values.UserAgent; + + var mongoClient = new MongoClient(settings); + var database = mongoClient.GetDatabase(databaseName); + + var selectedOptions = options ?? sp.GetService(); + + return new MongoDBVectorStore(database, options); + }); + + return services; + } + + /// + /// Register a MongoDB and with the specified service ID + /// and where the MongoDB is retrieved from the dependency injection container. + /// + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddMongoDBVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + MongoDBVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + var database = sp.GetRequiredService(); + var selectedOptions = options ?? sp.GetService>(); + + return new MongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Register a MongoDB and with the specified service ID + /// and where the MongoDB is constructed using the provided and . + /// + /// The type of the record. + /// The to register the on. + /// The name of the collection. + /// Connection string required to connect to MongoDB. + /// Database name for MongoDB. + /// Optional options to further configure the . + /// An optional service id to use as the service key. + /// Service collection. + public static IServiceCollection AddMongoDBVectorStoreRecordCollection( + this IServiceCollection services, + string collectionName, + string connectionString, + string databaseName, + MongoDBVectorStoreRecordCollectionOptions? options = default, + string? serviceId = default) + { + services.AddKeyedSingleton>( + serviceId, + (sp, obj) => + { + var settings = MongoClientSettings.FromConnectionString(connectionString); + settings.ApplicationName = HttpHeaderConstant.Values.UserAgent; + + var mongoClient = new MongoClient(settings); + var database = mongoClient.GetDatabase(databaseName); + + var selectedOptions = options ?? sp.GetService>(); + + return new MongoDBVectorStoreRecordCollection(database, collectionName, selectedOptions); + }); + + AddVectorizedSearch(services, serviceId); + + return services; + } + + /// + /// Also register the with the given as a . + /// + /// The type of the data model that the collection should contain. + /// The service collection to register on. + /// The service id that the registrations should use. + private static void AddVectorizedSearch(IServiceCollection services, string? serviceId) + { + services.AddKeyedTransient>( + serviceId, + (sp, obj) => + { + return sp.GetRequiredKeyedService>(serviceId); + }); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs new file mode 100644 index 000000000000..5d0fdd6781d5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStore.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.Extensions.VectorData; +using MongoDB.Driver; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Class for accessing the list of collections in a MongoDB vector store. +/// +/// +/// This class can be used with collections of any schema type, but requires you to provide schema information when getting a collection. +/// +public sealed class MongoDBVectorStore : IVectorStore +{ + /// that can be used to manage the collections in MongoDB. + private readonly IMongoDatabase _mongoDatabase; + + /// Optional configuration options for this class. + private readonly MongoDBVectorStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// that can be used to manage the collections in MongoDB. + /// Optional configuration options for this class. + public MongoDBVectorStore(IMongoDatabase mongoDatabase, MongoDBVectorStoreOptions? options = default) + { + Verify.NotNull(mongoDatabase); + + this._mongoDatabase = mongoDatabase; + this._options = options ?? new(); + } + + /// + public IVectorStoreRecordCollection GetCollection(string name, VectorStoreRecordDefinition? vectorStoreRecordDefinition = null) + where TKey : notnull + { + if (typeof(TKey) != typeof(string)) + { + throw new NotSupportedException("Only string keys are supported."); + } + + if (this._options.VectorStoreCollectionFactory is not null) + { + return this._options.VectorStoreCollectionFactory.CreateVectorStoreRecordCollection(this._mongoDatabase, name, vectorStoreRecordDefinition); + } + + var recordCollection = new MongoDBVectorStoreRecordCollection( + this._mongoDatabase, + name, + new() { VectorStoreRecordDefinition = vectorStoreRecordDefinition }) as IVectorStoreRecordCollection; + + return recordCollection!; + } + + /// + public async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using var cursor = await this._mongoDatabase + .ListCollectionNamesAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + + while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) + { + foreach (var name in cursor.Current) + { + yield return name; + } + } + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs new file mode 100644 index 000000000000..36d0b9ad8c1e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionCreateMapping.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Contains mapping helpers to use when creating a collection in MongoDB. +/// +internal static class MongoDBVectorStoreCollectionCreateMapping +{ + /// + /// Returns an array of indexes to create for vector properties. + /// + /// Collection of vector properties for index creation. + /// A dictionary that maps from a property name to the storage name. + public static BsonArray GetVectorIndexFields( + IReadOnlyList vectorProperties, + Dictionary storagePropertyNames) + { + var indexArray = new BsonArray(); + + // Create separate index for each vector property + foreach (var property in vectorProperties) + { + // Use index name same as vector property name with underscore + var vectorPropertyName = storagePropertyNames[property.DataModelPropertyName]; + + var indexDocument = new BsonDocument + { + { "type", "vector" }, + { "numDimensions", property.Dimensions }, + { "path", vectorPropertyName }, + { "similarity", GetDistanceFunction(property.DistanceFunction, vectorPropertyName) }, + }; + + indexArray.Add(indexDocument); + } + + return indexArray; + } + + /// + /// Returns an array of indexes to create for filterable data properties. + /// + /// Collection of data properties for index creation. + /// A dictionary that maps from a property name to the storage name. + public static BsonArray GetFilterableDataIndexFields( + IReadOnlyList dataProperties, + Dictionary storagePropertyNames) + { + var indexArray = new BsonArray(); + + // Create separate index for each data property + foreach (var property in dataProperties) + { + if (property.IsFilterable) + { + // Use index name same as data property name with underscore + var dataPropertyName = storagePropertyNames[property.DataModelPropertyName]; + + var indexDocument = new BsonDocument + { + { "type", "filter" }, + { "path", dataPropertyName }, + }; + + indexArray.Add(indexDocument); + } + } + + return indexArray; + } + + /// + /// More information about MongoDB distance functions here: . + /// + private static string GetDistanceFunction(string? distanceFunction, string vectorPropertyName) + { + var vectorPropertyDistanceFunction = MongoDBVectorStoreCollectionSearchMapping.GetVectorPropertyDistanceFunction(distanceFunction); + + return vectorPropertyDistanceFunction switch + { + DistanceFunction.CosineSimilarity => "cosine", + DistanceFunction.DotProductSimilarity => "dotProduct", + DistanceFunction.EuclideanDistance => "euclidean", + _ => throw new InvalidOperationException($"Distance function '{distanceFunction}' for {nameof(VectorStoreRecordVectorProperty)} '{vectorPropertyName}' is not supported by the MongoDB VectorStore.") + }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs new file mode 100644 index 000000000000..931b668f535d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreCollectionSearchMapping.cs @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.Extensions.VectorData; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Contains mapping helpers to use when searching for documents using MongoDB. +/// +internal static class MongoDBVectorStoreCollectionSearchMapping +{ + /// Returns distance function specified on vector property or default . + public static string GetVectorPropertyDistanceFunction(string? distanceFunction) => !string.IsNullOrWhiteSpace(distanceFunction) ? distanceFunction! : MongoDBConstants.DefaultDistanceFunction; + + /// + /// Build MongoDB filter from the provided . + /// + /// The to build MongoDB filter from. + /// A dictionary that maps from a property name to the storage name. + /// Thrown when the provided filter type is unsupported. + /// Thrown when property name specified in filter doesn't exist. + public static BsonDocument? BuildFilter( + VectorSearchFilter? vectorSearchFilter, + Dictionary storagePropertyNames) + { + const string EqualOperator = "$eq"; + + var filterClauses = vectorSearchFilter?.FilterClauses.ToList(); + + if (filterClauses is not { Count: > 0 }) + { + return null; + } + + var filter = new BsonDocument(); + + foreach (var filterClause in filterClauses) + { + string propertyName; + BsonValue propertyValue; + string filterOperator; + + if (filterClause is EqualToFilterClause equalToFilterClause) + { + propertyName = equalToFilterClause.FieldName; + propertyValue = BsonValue.Create(equalToFilterClause.Value); + filterOperator = EqualOperator; + } + else + { + throw new NotSupportedException( + $"Unsupported filter clause type '{filterClause.GetType().Name}'. " + + $"Supported filter clause types are: {string.Join(", ", [ + nameof(EqualToFilterClause)])}"); + } + + if (!storagePropertyNames.TryGetValue(propertyName, out var storagePropertyName)) + { + throw new InvalidOperationException($"Property name '{propertyName}' provided as part of the filter clause is not a valid property name."); + } + + if (filter.Contains(storagePropertyName)) + { + if (filter[storagePropertyName] is BsonDocument document && document.Contains(filterOperator)) + { + throw new NotSupportedException( + $"Filter with operator '{filterOperator}' is already added to '{propertyName}' property. " + + "Multiple filters of the same type in the same property are not supported."); + } + + filter[storagePropertyName][filterOperator] = propertyValue; + } + else + { + filter[storagePropertyName] = new BsonDocument() { [filterOperator] = propertyValue }; + } + } + + return filter; + } + + /// Returns search part of the search query. + public static BsonDocument GetSearchQuery( + TVector vector, + string indexName, + string vectorPropertyName, + int limit, + int numCandidates, + BsonDocument? filter) + { + var searchQuery = new BsonDocument + { + { "index", indexName }, + { "queryVector", BsonArray.Create(vector) }, + { "path", vectorPropertyName }, + { "limit", limit }, + { "numCandidates", numCandidates }, + }; + + if (filter is not null) + { + searchQuery["filter"] = filter; + } + + return new BsonDocument + { + { "$vectorSearch", searchQuery } + }; + } + + /// Returns projection part of the search query to return similarity score together with document. + public static BsonDocument GetProjectionQuery(string scorePropertyName, string documentPropertyName) + { + return new BsonDocument + { + { "$project", + new BsonDocument + { + { scorePropertyName, new BsonDocument { { "$meta", "vectorSearchScore" } } }, + { documentPropertyName, "$$ROOT" } + } + } + }; + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs new file mode 100644 index 000000000000..56388b2652da --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreOptions.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Options when creating a +/// +public sealed class MongoDBVectorStoreOptions +{ + /// + /// An optional factory to use for constructing instances, if a custom record collection is required. + /// + public IMongoDBVectorStoreRecordCollectionFactory? VectorStoreCollectionFactory { get; init; } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs new file mode 100644 index 000000000000..f27c8a975bc3 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollection.cs @@ -0,0 +1,615 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Service for storing and retrieving vector records, that uses MongoDB as the underlying storage. +/// +/// The data model to use for adding, updating and retrieving data from storage. +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class MongoDBVectorStoreRecordCollection : IVectorStoreRecordCollection +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix +{ + /// The name of this database for telemetry purposes. + private const string DatabaseName = "MongoDB"; + + /// Property name to be used for search similarity score value. + private const string ScorePropertyName = "similarityScore"; + + /// Property name to be used for search document value. + private const string DocumentPropertyName = "document"; + + /// The default options for vector search. + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + + /// that can be used to manage the collections in MongoDB. + private readonly IMongoDatabase _mongoDatabase; + + /// MongoDB collection to perform record operations. + private readonly IMongoCollection _mongoCollection; + + /// Optional configuration options for this class. + private readonly MongoDBVectorStoreRecordCollectionOptions _options; + + /// Interface for mapping between a storage model, and the consumer record data model. + private readonly IVectorStoreRecordMapper _mapper; + + /// A dictionary that maps from a property name to the storage name that should be used when serializing it for data and vector properties. + private readonly Dictionary _storagePropertyNames; + + /// Collection of vector storage property names. + private readonly List _vectorStoragePropertyNames; + + /// A helper to access property information for the current data model and record definition. + private readonly VectorStoreRecordPropertyReader _propertyReader; + + /// + public string CollectionName { get; } + + /// + /// Initializes a new instance of the class. + /// + /// that can be used to manage the collections in MongoDB. + /// The name of the collection that this will access. + /// Optional configuration options for this class. + public MongoDBVectorStoreRecordCollection( + IMongoDatabase mongoDatabase, + string collectionName, + MongoDBVectorStoreRecordCollectionOptions? options = default) + { + // Verify. + Verify.NotNull(mongoDatabase); + Verify.NotNullOrWhiteSpace(collectionName); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelKeyType(typeof(TRecord), options?.BsonDocumentCustomMapper is not null, MongoDBConstants.SupportedKeyTypes); + VectorStoreRecordPropertyVerification.VerifyGenericDataModelDefinitionSupplied(typeof(TRecord), options?.VectorStoreRecordDefinition is not null); + + // Assign. + this._mongoDatabase = mongoDatabase; + this._mongoCollection = mongoDatabase.GetCollection(collectionName); + this.CollectionName = collectionName; + this._options = options ?? new MongoDBVectorStoreRecordCollectionOptions(); + this._propertyReader = new VectorStoreRecordPropertyReader(typeof(TRecord), this._options.VectorStoreRecordDefinition, new() { RequiresAtLeastOneVector = false, SupportsMultipleKeys = false, SupportsMultipleVectors = true }); + + this._storagePropertyNames = GetStoragePropertyNames(this._propertyReader.Properties, typeof(TRecord)); + + // Use Mongo reserved key property name as storage key property name + this._storagePropertyNames[this._propertyReader.KeyPropertyName] = MongoDBConstants.MongoReservedKeyPropertyName; + + this._vectorStoragePropertyNames = this._propertyReader.VectorProperties.Select(property => this._storagePropertyNames[property.DataModelPropertyName]).ToList(); + + this._mapper = this.InitializeMapper(); + } + + /// + public Task CollectionExistsAsync(CancellationToken cancellationToken = default) + => this.RunOperationAsync("ListCollectionNames", () => this.InternalCollectionExistsAsync(cancellationToken)); + + /// + public async Task CreateCollectionAsync(CancellationToken cancellationToken = default) + { + await this.RunOperationAsync("CreateCollection", + () => this._mongoDatabase.CreateCollectionAsync(this.CollectionName, cancellationToken: cancellationToken)).ConfigureAwait(false); + + await this.RunOperationWithRetryAsync( + "CreateIndexes", + this._options.MaxRetries, + this._options.DelayInMilliseconds, + () => this.CreateIndexesAsync(this.CollectionName, cancellationToken), + cancellationToken).ConfigureAwait(false); + } + + /// + public async Task CreateCollectionIfNotExistsAsync(CancellationToken cancellationToken = default) + { + if (!await this.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + await this.CreateCollectionAsync(cancellationToken).ConfigureAwait(false); + } + } + + /// + public async Task DeleteAsync(string key, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken)) + .ConfigureAwait(false); + } + + /// + public async Task DeleteBatchAsync(IEnumerable keys, DeleteRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + await this.RunOperationAsync("DeleteMany", () => this._mongoCollection.DeleteManyAsync(this.GetFilterByIds(keys), cancellationToken)) + .ConfigureAwait(false); + } + + /// + public Task DeleteCollectionAsync(CancellationToken cancellationToken = default) + => this.RunOperationAsync("DropCollection", () => this._mongoDatabase.DropCollectionAsync(this.CollectionName, cancellationToken)); + + /// + public async Task GetAsync(string key, GetRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNullOrWhiteSpace(key); + + const string OperationName = "Find"; + + var includeVectors = options?.IncludeVectors ?? false; + + var record = await this.RunOperationAsync(OperationName, async () => + { + using var cursor = await this + .FindAsync(this.GetFilterById(key), options, cancellationToken) + .ConfigureAwait(false); + + return await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false); + }).ConfigureAwait(false); + + if (record is null) + { + return default; + } + + return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(record, new() { IncludeVectors = includeVectors })); + } + + /// + public async IAsyncEnumerable GetBatchAsync( + IEnumerable keys, + GetRecordOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(keys); + + const string OperationName = "Find"; + + using var cursor = await this + .FindAsync(this.GetFilterByIds(keys), options, cancellationToken) + .ConfigureAwait(false); + + while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) + { + foreach (var record in cursor.Current) + { + if (record is not null) + { + yield return VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(record, new())); + } + } + } + } + + /// + public Task UpsertAsync(TRecord record, UpsertRecordOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(record); + + const string OperationName = "ReplaceOne"; + + var replaceOptions = new ReplaceOptions { IsUpsert = true }; + var storageModel = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromDataToStorageModel(record)); + + var key = storageModel[MongoDBConstants.MongoReservedKeyPropertyName].AsString; + + return this.RunOperationAsync(OperationName, async () => + { + await this._mongoCollection + .ReplaceOneAsync(this.GetFilterById(key), storageModel, replaceOptions, cancellationToken) + .ConfigureAwait(false); + + return key; + }); + } + + /// + public async IAsyncEnumerable UpsertBatchAsync( + IEnumerable records, + UpsertRecordOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(records); + + var tasks = records.Select(record => this.UpsertAsync(record, options, cancellationToken)); + var results = await Task.WhenAll(tasks).ConfigureAwait(false); + + foreach (var result in results) + { + if (result is not null) + { + yield return result; + } + } + } + + /// + public async Task> VectorizedSearchAsync( + TVector vector, + VectorSearchOptions? options = null, + CancellationToken cancellationToken = default) + { + Verify.NotNull(vector); + + Array vectorArray = vector switch + { + ReadOnlyMemory memoryFloat => memoryFloat.ToArray(), + ReadOnlyMemory memoryDouble => memoryDouble.ToArray(), + _ => throw new NotSupportedException( + $"The provided vector type {vector.GetType().FullName} is not supported by the MongoDB connector. " + + $"Supported types are: {string.Join(", ", [ + typeof(ReadOnlyMemory).FullName, + typeof(ReadOnlyMemory).FullName])}") + }; + + var searchOptions = options ?? s_defaultVectorSearchOptions; + var vectorProperty = this.GetVectorPropertyForSearch(searchOptions.VectorPropertyName); + + if (vectorProperty is null) + { + throw new InvalidOperationException("The collection does not have any vector properties, so vector search is not possible."); + } + + var vectorPropertyName = this._storagePropertyNames[vectorProperty.DataModelPropertyName]; + + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter( + searchOptions.Filter, + this._storagePropertyNames); + + // Constructing a query to fetch "skip + top" total items + // to perform skip logic locally, since skip option is not part of API. + var itemsAmount = searchOptions.Skip + searchOptions.Top; + + var numCandidates = this._options.NumCandidates ?? itemsAmount * MongoDBConstants.DefaultNumCandidatesRatio; + + var searchQuery = MongoDBVectorStoreCollectionSearchMapping.GetSearchQuery( + vectorArray, + this._options.VectorIndexName, + vectorPropertyName, + itemsAmount, + numCandidates, + filter); + + var projectionQuery = MongoDBVectorStoreCollectionSearchMapping.GetProjectionQuery( + ScorePropertyName, + DocumentPropertyName); + + BsonDocument[] pipeline = [searchQuery, projectionQuery]; + + return await this.RunOperationWithRetryAsync( + "VectorizedSearch", + this._options.MaxRetries, + this._options.DelayInMilliseconds, + async () => + { + var cursor = await this._mongoCollection + .AggregateAsync(pipeline, cancellationToken: cancellationToken) + .ConfigureAwait(false); + + return new VectorSearchResults(this.EnumerateAndMapSearchResultsAsync(cursor, searchOptions, cancellationToken)); + }, + cancellationToken).ConfigureAwait(false); + } + + #region private + + private async Task CreateIndexesAsync(string collectionName, CancellationToken cancellationToken) + { + var indexCursor = await this._mongoCollection.Indexes.ListAsync(cancellationToken).ConfigureAwait(false); + var indexes = indexCursor.ToList(cancellationToken).Select(index => index["name"].ToString()) ?? []; + + if (indexes.Contains(this._options.VectorIndexName)) + { + // Vector index already exists. + return; + } + + var fieldsArray = new BsonArray(); + + fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetVectorIndexFields( + this._propertyReader.VectorProperties, + this._storagePropertyNames)); + + fieldsArray.AddRange(MongoDBVectorStoreCollectionCreateMapping.GetFilterableDataIndexFields( + this._propertyReader.DataProperties, + this._storagePropertyNames)); + + if (fieldsArray.Count > 0) + { + var indexArray = new BsonArray + { + new BsonDocument + { + { "name", this._options.VectorIndexName }, + { "type", "vectorSearch" }, + { "definition", new BsonDocument { ["fields"] = fieldsArray } }, + } + }; + + var createIndexCommand = new BsonDocument + { + { "createSearchIndexes", collectionName }, + { "indexes", indexArray } + }; + + await this._mongoDatabase.RunCommandAsync(createIndexCommand, cancellationToken: cancellationToken).ConfigureAwait(false); + } + } + + private async Task> FindAsync(FilterDefinition filter, GetRecordOptions? options, CancellationToken cancellationToken) + { + ProjectionDefinitionBuilder projectionBuilder = Builders.Projection; + ProjectionDefinition? projectionDefinition = null; + + var includeVectors = options?.IncludeVectors ?? false; + + if (!includeVectors && this._vectorStoragePropertyNames.Count > 0) + { + foreach (var vectorPropertyName in this._vectorStoragePropertyNames) + { + projectionDefinition = projectionDefinition is not null ? + projectionDefinition.Exclude(vectorPropertyName) : + projectionBuilder.Exclude(vectorPropertyName); + } + } + + var findOptions = projectionDefinition is not null ? + new FindOptions { Projection = projectionDefinition } : + null; + + return await this._mongoCollection.FindAsync(filter, findOptions, cancellationToken).ConfigureAwait(false); + } + + private async IAsyncEnumerable> EnumerateAndMapSearchResultsAsync( + IAsyncCursor cursor, + VectorSearchOptions searchOptions, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + const string OperationName = "Aggregate"; + + var skipCounter = 0; + + while (await cursor.MoveNextAsync(cancellationToken).ConfigureAwait(false)) + { + foreach (var response in cursor.Current) + { + if (skipCounter >= searchOptions.Skip) + { + var score = response[ScorePropertyName].AsDouble; + var record = VectorStoreErrorHandler.RunModelConversion( + DatabaseName, + this.CollectionName, + OperationName, + () => this._mapper.MapFromStorageToDataModel(response[DocumentPropertyName].AsBsonDocument, new() { IncludeVectors = searchOptions.IncludeVectors })); + + yield return new VectorSearchResult(record, score); + } + + skipCounter++; + } + } + } + + private FilterDefinition GetFilterById(string id) + => Builders.Filter.Eq(document => document[MongoDBConstants.MongoReservedKeyPropertyName], id); + + private FilterDefinition GetFilterByIds(IEnumerable ids) + => Builders.Filter.In(document => document[MongoDBConstants.MongoReservedKeyPropertyName].AsString, ids); + + private async Task InternalCollectionExistsAsync(CancellationToken cancellationToken) + { + var filter = new BsonDocument("name", this.CollectionName); + var options = new ListCollectionNamesOptions { Filter = filter }; + + using var cursor = await this._mongoDatabase.ListCollectionNamesAsync(options, cancellationToken: cancellationToken).ConfigureAwait(false); + + return await cursor.AnyAsync(cancellationToken).ConfigureAwait(false); + } + + private async Task RunOperationAsync(string operationName, Func operation) + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationAsync(string operationName, Func> operation) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + } + + private async Task RunOperationWithRetryAsync( + string operationName, + int maxRetries, + int delayInMilliseconds, + Func operation, + CancellationToken cancellationToken) + { + var retries = 0; + + while (retries < maxRetries) + { + try + { + await operation.Invoke().ConfigureAwait(false); + return; + } + catch (Exception ex) + { + retries++; + + if (retries >= maxRetries) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + + await Task.Delay(delayInMilliseconds, cancellationToken).ConfigureAwait(false); + } + } + } + + private async Task RunOperationWithRetryAsync( + string operationName, + int maxRetries, + int delayInMilliseconds, + Func> operation, + CancellationToken cancellationToken) + { + var retries = 0; + + while (retries < maxRetries) + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (Exception ex) + { + retries++; + + if (retries >= maxRetries) + { + throw new VectorStoreOperationException("Call to vector store failed.", ex) + { + VectorStoreType = DatabaseName, + CollectionName = this.CollectionName, + OperationName = operationName + }; + } + + await Task.Delay(delayInMilliseconds, cancellationToken).ConfigureAwait(false); + } + } + + throw new VectorStoreOperationException("Retry logic failed."); + } + + /// + /// Gets storage property names taking into account BSON serialization attributes. + /// + private static Dictionary GetStoragePropertyNames( + IReadOnlyList properties, + Type dataModel) + { + var storagePropertyNames = new Dictionary(); + + foreach (var property in properties) + { + var propertyInfo = dataModel.GetProperty(property.DataModelPropertyName); + string propertyName; + + if (propertyInfo != null) + { + var bsonElementAttribute = propertyInfo.GetCustomAttribute(); + + propertyName = bsonElementAttribute?.ElementName ?? property.DataModelPropertyName; + } + else + { + propertyName = property.DataModelPropertyName; + } + + storagePropertyNames[property.DataModelPropertyName] = propertyName; + } + + return storagePropertyNames; + } + + /// + /// Get vector property to use for a search by using the storage name for the field name from options + /// if available, and falling back to the first vector property in if not. + /// + /// The vector field name. + /// Thrown if the provided field name is not a valid field name. + private VectorStoreRecordVectorProperty? GetVectorPropertyForSearch(string? vectorFieldName) + { + // If vector property name is provided in options, try to find it in schema or throw an exception. + if (!string.IsNullOrWhiteSpace(vectorFieldName)) + { + // Check vector properties by data model property name. + var vectorProperty = this._propertyReader.VectorProperties + .FirstOrDefault(l => l.DataModelPropertyName.Equals(vectorFieldName, StringComparison.Ordinal)); + + if (vectorProperty is not null) + { + return vectorProperty; + } + + throw new InvalidOperationException($"The {typeof(TRecord).FullName} type does not have a vector property named '{vectorFieldName}'."); + } + + // If vector property is not provided in options, return first vector property from schema. + return this._propertyReader.VectorProperty; + } + + /// + /// Returns custom mapper, generic data model mapper or default record mapper. + /// + private IVectorStoreRecordMapper InitializeMapper() + { + if (this._options.BsonDocumentCustomMapper is not null) + { + return this._options.BsonDocumentCustomMapper; + } + + if (typeof(TRecord) == typeof(VectorStoreGenericDataModel)) + { + return (new MongoDBGenericDataModelMapper(this._propertyReader.RecordDefinition) as IVectorStoreRecordMapper)!; + } + + return new MongoDBVectorStoreRecordMapper(this._propertyReader); + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs new file mode 100644 index 000000000000..97f48a53dfa1 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.MongoDB/MongoDBVectorStoreRecordCollectionOptions.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using MongoDB.Bson; + +namespace Microsoft.SemanticKernel.Connectors.MongoDB; + +/// +/// Options when creating a . +/// +public sealed class MongoDBVectorStoreRecordCollectionOptions +{ + /// + /// Gets or sets an optional custom mapper to use when converting between the data model and the MongoDB BSON object. + /// + public IVectorStoreRecordMapper? BsonDocumentCustomMapper { get; init; } = null; + + /// + /// Gets or sets an optional record definition that defines the schema of the record type. + /// + /// + /// If not provided, the schema will be inferred from the record model class using reflection. + /// In this case, the record model properties must be annotated with the appropriate attributes to indicate their usage. + /// See , and . + /// + public VectorStoreRecordDefinition? VectorStoreRecordDefinition { get; init; } = null; + + /// + /// Vector index name to use. If null, the default "vector_index" name will be used. + /// + public string VectorIndexName { get; init; } = MongoDBConstants.DefaultVectorIndexName; + + /// + /// Number of max retries for vector collection operation. + /// + public int MaxRetries { get; init; } = 5; + + /// + /// Delay in milliseconds between retries for vector collection operation. + /// + public int DelayInMilliseconds { get; init; } = 1_000; + + /// + /// Number of nearest neighbors to use during the vector search. + /// Value must be less than or equal to 10000. + /// Recommended value should be higher than number of documents to return. + /// If not provided, "number of documents * 10" value will be used. + /// + public int? NumCandidates { get; init; } = null; +} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/.editorconfig b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/.editorconfig new file mode 100644 index 000000000000..394eef685f21 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/.editorconfig @@ -0,0 +1,6 @@ +# Suppressing errors for Test projects under dotnet folder +[*.cs] +dotnet_diagnostic.CA2007.severity = none # Do not directly await a Task +dotnet_diagnostic.VSTHRD111.severity = none # Use .ConfigureAwait(bool) is hidden by default, set to none to prevent IDE from changing on autosave +dotnet_diagnostic.CS1591.severity = none # Missing XML comment for publicly visible type or member +dotnet_diagnostic.IDE1006.severity = warning # Naming rule violations diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj new file mode 100644 index 000000000000..b8969e21943e --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/Connectors.MongoDB.UnitTests.csproj @@ -0,0 +1,32 @@ + + + + SemanticKernel.Connectors.MongoDB.UnitTests + SemanticKernel.Connectors.MongoDB.UnitTests + net8.0 + true + enable + disable + false + $(NoWarn);SKEXP0001,SKEXP0020,VSTHRD111,CA2007,CS1591 + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs similarity index 93% rename from dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs rename to dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs index 0f030bed837b..1e19af61a2f4 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBGenericDataModelMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBGenericDataModelMapperTests.cs @@ -4,16 +4,16 @@ using System.Collections.Generic; using System.Linq; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using Xunit; -namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; +namespace SemanticKernel.Connectors.MongoDB.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// -public sealed class AzureCosmosDBMongoDBGenericDataModelMapperTests +public sealed class MongoDBGenericDataModelMapperTests { private static readonly VectorStoreRecordDefinition s_vectorStoreRecordDefinition = new() { @@ -51,7 +51,7 @@ public sealed class AzureCosmosDBMongoDBGenericDataModelMapperTests public void MapFromDataToStorageModelMapsAllSupportedTypes() { // Arrange - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); var dataModel = new VectorStoreGenericDataModel("key") { Data = @@ -137,7 +137,7 @@ public void MapFromDataToStorageModelMapsNullValues() }, }; - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); // Act var storageModel = sut.MapFromDataToStorageModel(dataModel); @@ -152,7 +152,7 @@ public void MapFromDataToStorageModelMapsNullValues() public void MapFromStorageToDataModelMapsAllSupportedTypes() { // Arrange - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); var storageModel = new BsonDocument { ["_id"] = "key", @@ -228,7 +228,7 @@ public void MapFromStorageToDataModelMapsNullValues() ["NullableFloatVector"] = BsonNull.Value }; - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); // Act var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); @@ -244,7 +244,7 @@ public void MapFromStorageToDataModelMapsNullValues() public void MapFromStorageToDataModelThrowsForMissingKey() { // Arrange - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(s_vectorStoreRecordDefinition); var storageModel = new BsonDocument(); // Act & Assert @@ -267,7 +267,7 @@ public void MapFromDataToStorageModelSkipsMissingProperties() }; var dataModel = new VectorStoreGenericDataModel("key"); - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); // Act var storageModel = sut.MapFromDataToStorageModel(dataModel); @@ -297,7 +297,7 @@ public void MapFromStorageToDataModelSkipsMissingProperties() ["_id"] = "key" }; - var sut = new AzureCosmosDBMongoDBGenericDataModelMapper(vectorStoreRecordDefinition); + var sut = new MongoDBGenericDataModelMapper(vectorStoreRecordDefinition); // Act var dataModel = sut.MapFromStorageToDataModel(storageModel, new StorageToDataModelMapperOptions { IncludeVectors = true }); diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs new file mode 100644 index 000000000000..46374a5cc408 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBHotelModel.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using MongoDB.Bson.Serialization.Attributes; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +public class MongoDBHotelModel(string hotelId) +{ + /// The key of the record. + [VectorStoreRecordKey] + public string HotelId { get; init; } = hotelId; + + /// A string metadata field. + [VectorStoreRecordData(IsFilterable = true)] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [BsonElement("parking_is_included")] + [VectorStoreRecordData] + public bool ParkingIncluded { get; set; } + + /// An array metadata field. + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + /// A data field. + [VectorStoreRecordData] + public string? Description { get; set; } + + /// A vector field. + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } +} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..ac6f401583ac --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBServiceCollectionExtensionsTests.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using Microsoft.SemanticKernel.Http; +using MongoDB.Driver; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class MongoDBServiceCollectionExtensionsTests +{ + private readonly IServiceCollection _serviceCollection = new ServiceCollection(); + + [Fact] + public void AddVectorStoreRegistersClass() + { + // Arrange + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act + this._serviceCollection.AddMongoDBVectorStore(); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + + // Assert + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + } + + [Fact] + public void AddVectorStoreWithConnectionStringRegistersClass() + { + // Act + this._serviceCollection.AddMongoDBVectorStore("mongodb://localhost:27017", "mydb"); + + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + var vectorStore = serviceProvider.GetRequiredService(); + + // Assert + Assert.NotNull(vectorStore); + Assert.IsType(vectorStore); + + var database = (IMongoDatabase)vectorStore.GetType().GetField("_mongoDatabase", BindingFlags.NonPublic | BindingFlags.Instance)!.GetValue(vectorStore)!; + Assert.Equal(HttpHeaderConstant.Values.UserAgent, database.Client.Settings.ApplicationName); + } + + [Fact] + public void AddVectorStoreRecordCollectionRegistersClass() + { + // Arrange + this._serviceCollection.AddSingleton(Mock.Of()); + + // Act + this._serviceCollection.AddMongoDBVectorStoreRecordCollection("testcollection"); + + // Assert + this.AssertVectorStoreRecordCollectionCreated(); + } + + [Fact] + public void AddVectorStoreRecordCollectionWithConnectionStringRegistersClass() + { + // Act + this._serviceCollection.AddMongoDBVectorStoreRecordCollection("testcollection", "mongodb://localhost:27017", "mydb"); + + // Assert + this.AssertVectorStoreRecordCollectionCreated(); + } + + private void AssertVectorStoreRecordCollectionCreated() + { + var serviceProvider = this._serviceCollection.BuildServiceProvider(); + + var collection = serviceProvider.GetRequiredService>(); + Assert.NotNull(collection); + Assert.IsType>(collection); + + var vectorizedSearch = serviceProvider.GetRequiredService>(); + Assert.NotNull(vectorizedSearch); + Assert.IsType>(vectorizedSearch); + } + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + private sealed class TestRecord +#pragma warning restore CA1812 // Avoid uninstantiated internal classes + { + [VectorStoreRecordKey] + public string Id { get; set; } = string.Empty; + } +} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs new file mode 100644 index 000000000000..8242333ecea5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreCollectionSearchMappingTests.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Bson; +using Xunit; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class MongoDBVectorStoreCollectionSearchMappingTests +{ + private readonly Dictionary _storagePropertyNames = new() + { + ["Property1"] = "property_1", + ["Property2"] = "property_2", + }; + + [Fact] + public void BuildFilterWithNullVectorSearchFilterReturnsNull() + { + // Arrange + VectorSearchFilter? vectorSearchFilter = null; + + // Act + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + + // Assert + Assert.Null(filter); + } + + [Fact] + public void BuildFilterWithoutFilterClausesReturnsNull() + { + // Arrange + VectorSearchFilter vectorSearchFilter = new(); + + // Act + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + + // Assert + Assert.Null(filter); + } + + [Fact] + public void BuildFilterThrowsExceptionWithUnsupportedFilterClause() + { + // Arrange + var vectorSearchFilter = new VectorSearchFilter().AnyTagEqualTo("NonExistentProperty", "TestValue"); + + // Act & Assert + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + } + + [Fact] + public void BuildFilterThrowsExceptionWithNonExistentPropertyName() + { + // Arrange + var vectorSearchFilter = new VectorSearchFilter().EqualTo("NonExistentProperty", "TestValue"); + + // Act & Assert + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + } + + [Fact] + public void BuildFilterThrowsExceptionWithMultipleFilterClausesOfSameType() + { + // Arrange + var vectorSearchFilter = new VectorSearchFilter() + .EqualTo("Property1", "TestValue1") + .EqualTo("Property1", "TestValue2"); + + // Act & Assert + Assert.Throws(() => MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames)); + } + + [Fact] + public void BuilderFilterByDefaultReturnsValidFilter() + { + // Arrange + var expectedFilter = new BsonDocument() { ["property_1"] = new BsonDocument() { ["$eq"] = "TestValue1" } }; + var vectorSearchFilter = new VectorSearchFilter().EqualTo("Property1", "TestValue1"); + + // Act + var filter = MongoDBVectorStoreCollectionSearchMapping.BuildFilter(vectorSearchFilter, this._storagePropertyNames); + + Assert.Equal(filter.ToJson(), expectedFilter.ToJson()); + } +} diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..26a9b9fb00b7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordCollectionTests.cs @@ -0,0 +1,843 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class MongoDBVectorStoreRecordCollectionTests +{ + private readonly Mock _mockMongoDatabase = new(); + private readonly Mock> _mockMongoCollection = new(); + + public MongoDBVectorStoreRecordCollectionTests() + { + this._mockMongoDatabase + .Setup(l => l.GetCollection(It.IsAny(), It.IsAny())) + .Returns(this._mockMongoCollection.Object); + } + + [Fact] + public void ConstructorForModelWithoutKeyThrowsException() + { + // Act & Assert + var exception = Assert.Throws(() => new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, "collection")); + Assert.Contains("No key property found", exception.Message); + } + + [Fact] + public void ConstructorWithDeclarativeModelInitializesCollection() + { + // Act & Assert + var collection = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + Assert.NotNull(collection); + } + + [Fact] + public void ConstructorWithImperativeModelInitializesCollection() + { + // Arrange + var definition = new VectorStoreRecordDefinition + { + Properties = [new VectorStoreRecordKeyProperty("Id", typeof(string))] + }; + + // Act + var collection = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + new() { VectorStoreRecordDefinition = definition }); + + // Assert + Assert.NotNull(collection); + } + + [Theory] + [MemberData(nameof(CollectionExistsData))] + public async Task CollectionExistsReturnsValidResultAsync(List collections, string collectionName, bool expectedResult) + { + // Arrange + var mockCursor = new Mock>(); + + mockCursor + .Setup(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true); + + mockCursor + .Setup(l => l.Current) + .Returns(collections); + + this._mockMongoDatabase + .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + collectionName); + + // Act + var actualResult = await sut.CollectionExistsAsync(); + + // Assert + Assert.Equal(expectedResult, actualResult); + } + + [Theory] + [InlineData(true, 0)] + [InlineData(false, 1)] + public async Task CreateCollectionInvokesValidMethodsAsync(bool indexExists, int actualIndexCreations) + { + // Arrange + const string CollectionName = "collection"; + + List indexes = indexExists ? [new BsonDocument { ["name"] = "vector_index" }] : []; + + var mockIndexCursor = new Mock>(); + mockIndexCursor + .SetupSequence(l => l.MoveNext(It.IsAny())) + .Returns(true) + .Returns(false); + + mockIndexCursor + .Setup(l => l.Current) + .Returns(indexes); + + var mockMongoIndexManager = new Mock>(); + + mockMongoIndexManager + .Setup(l => l.ListAsync(It.IsAny())) + .ReturnsAsync(mockIndexCursor.Object); + + this._mockMongoCollection + .Setup(l => l.Indexes) + .Returns(mockMongoIndexManager.Object); + + var sut = new MongoDBVectorStoreRecordCollection(this._mockMongoDatabase.Object, CollectionName); + + // Act + await sut.CreateCollectionAsync(); + + // Assert + this._mockMongoDatabase.Verify(l => l.CreateCollectionAsync( + CollectionName, + It.IsAny(), + It.IsAny()), Times.Once()); + + this._mockMongoDatabase.Verify(l => l.RunCommandAsync( + It.Is>(command => + command.Document["createSearchIndexes"] == CollectionName && + command.Document["indexes"].GetType() == typeof(BsonArray) && + ((BsonArray)command.Document["indexes"]).Count == 1), + It.IsAny(), + It.IsAny()), Times.Exactly(actualIndexCreations)); + } + + [Theory] + [MemberData(nameof(CreateCollectionIfNotExistsData))] + public async Task CreateCollectionIfNotExistsInvokesValidMethodsAsync(List collections, int actualCollectionCreations) + { + // Arrange + const string CollectionName = "collection"; + + var mockCursor = new Mock>(); + mockCursor + .Setup(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true); + + mockCursor + .Setup(l => l.Current) + .Returns(collections); + + this._mockMongoDatabase + .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var mockIndexCursor = new Mock>(); + mockIndexCursor + .SetupSequence(l => l.MoveNext(It.IsAny())) + .Returns(true) + .Returns(false); + + mockIndexCursor + .Setup(l => l.Current) + .Returns([]); + + var mockMongoIndexManager = new Mock>(); + + mockMongoIndexManager + .Setup(l => l.ListAsync(It.IsAny())) + .ReturnsAsync(mockIndexCursor.Object); + + this._mockMongoCollection + .Setup(l => l.Indexes) + .Returns(mockMongoIndexManager.Object); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + CollectionName); + + // Act + await sut.CreateCollectionIfNotExistsAsync(); + + // Assert + this._mockMongoDatabase.Verify(l => l.CreateCollectionAsync( + CollectionName, + It.IsAny(), + It.IsAny()), Times.Exactly(actualCollectionCreations)); + } + + [Fact] + public async Task DeleteInvokesValidMethodsAsync() + { + // Arrange + const string RecordKey = "key"; + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + var expectedDefinition = Builders.Filter.Eq(document => document["_id"], RecordKey); + + // Act + await sut.DeleteAsync(RecordKey); + + // Assert + this._mockMongoCollection.Verify(l => l.DeleteOneAsync( + It.Is>(definition => + CompareFilterDefinitions(definition, expectedDefinition, documentSerializer, serializerRegistry)), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task DeleteBatchInvokesValidMethodsAsync() + { + // Arrange + List recordKeys = ["key1", "key2"]; + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + var expectedDefinition = Builders.Filter.In(document => document["_id"].AsString, recordKeys); + + // Act + await sut.DeleteBatchAsync(recordKeys); + + // Assert + this._mockMongoCollection.Verify(l => l.DeleteManyAsync( + It.Is>(definition => + CompareFilterDefinitions(definition, expectedDefinition, documentSerializer, serializerRegistry)), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task DeleteCollectionInvokesValidMethodsAsync() + { + // Arrange + const string CollectionName = "collection"; + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + CollectionName); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + this._mockMongoDatabase.Verify(l => l.DropCollectionAsync( + It.Is(name => name == CollectionName), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task GetReturnsValidRecordAsync() + { + // Arrange + const string RecordKey = "key"; + + var document = new BsonDocument { ["_id"] = RecordKey, ["HotelName"] = "Test Name" }; + + var mockCursor = new Mock>(); + mockCursor + .Setup(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true); + + mockCursor + .Setup(l => l.Current) + .Returns([document]); + + this._mockMongoCollection + .Setup(l => l.FindAsync( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var result = await sut.GetAsync(RecordKey); + + // Assert + Assert.NotNull(result); + Assert.Equal(RecordKey, result.HotelId); + Assert.Equal("Test Name", result.HotelName); + } + + [Fact] + public async Task GetBatchReturnsValidRecordAsync() + { + // Arrange + var document1 = new BsonDocument { ["_id"] = "key1", ["HotelName"] = "Test Name 1" }; + var document2 = new BsonDocument { ["_id"] = "key2", ["HotelName"] = "Test Name 2" }; + var document3 = new BsonDocument { ["_id"] = "key3", ["HotelName"] = "Test Name 3" }; + + var mockCursor = new Mock>(); + mockCursor + .SetupSequence(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockCursor + .Setup(l => l.Current) + .Returns([document1, document2, document3]); + + this._mockMongoCollection + .Setup(l => l.FindAsync( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var results = await sut.GetBatchAsync(["key1", "key2", "key3"]).ToListAsync(); + + // Assert + Assert.NotNull(results[0]); + Assert.Equal("key1", results[0].HotelId); + Assert.Equal("Test Name 1", results[0].HotelName); + + Assert.NotNull(results[1]); + Assert.Equal("key2", results[1].HotelId); + Assert.Equal("Test Name 2", results[1].HotelName); + + Assert.NotNull(results[2]); + Assert.Equal("key3", results[2].HotelId); + Assert.Equal("Test Name 3", results[2].HotelName); + } + + [Fact] + public async Task UpsertReturnsRecordKeyAsync() + { + // Arrange + var hotel = new MongoDBHotelModel("key") { HotelName = "Test Name" }; + + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + var expectedDefinition = Builders.Filter.Eq(document => document["_id"], "key"); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var result = await sut.UpsertAsync(hotel); + + // Assert + Assert.Equal("key", result); + + this._mockMongoCollection.Verify(l => l.ReplaceOneAsync( + It.Is>(definition => + CompareFilterDefinitions(definition, expectedDefinition, documentSerializer, serializerRegistry)), + It.Is(document => + document["_id"] == "key" && + document["HotelName"] == "Test Name"), + It.IsAny(), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task UpsertBatchReturnsRecordKeysAsync() + { + // Arrange + var hotel1 = new MongoDBHotelModel("key1") { HotelName = "Test Name 1" }; + var hotel2 = new MongoDBHotelModel("key2") { HotelName = "Test Name 2" }; + var hotel3 = new MongoDBHotelModel("key3") { HotelName = "Test Name 3" }; + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var results = await sut.UpsertBatchAsync([hotel1, hotel2, hotel3]).ToListAsync(); + + // Assert + Assert.NotNull(results); + Assert.Equal(3, results.Count); + + Assert.Equal("key1", results[0]); + Assert.Equal("key2", results[1]); + Assert.Equal("key3", results[2]); + } + + [Fact] + public async Task UpsertWithModelWorksCorrectlyAsync() + { + var definition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) + } + }; + + await this.TestUpsertWithModelAsync( + dataModel: new TestModel { Id = "key", HotelName = "Test Name" }, + expectedPropertyName: "HotelName", + definition: definition); + } + + [Fact] + public async Task UpsertWithVectorStoreModelWorksCorrectlyAsync() + { + await this.TestUpsertWithModelAsync( + dataModel: new VectorStoreTestModel { Id = "key", HotelName = "Test Name" }, + expectedPropertyName: "HotelName"); + } + + [Fact] + public async Task UpsertWithBsonModelWorksCorrectlyAsync() + { + var definition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) + } + }; + + await this.TestUpsertWithModelAsync( + dataModel: new BsonTestModel { Id = "key", HotelName = "Test Name" }, + expectedPropertyName: "hotel_name", + definition: definition); + } + + [Fact] + public async Task UpsertWithBsonVectorStoreModelWorksCorrectlyAsync() + { + await this.TestUpsertWithModelAsync( + dataModel: new BsonVectorStoreTestModel { Id = "key", HotelName = "Test Name" }, + expectedPropertyName: "hotel_name"); + } + + [Fact] + public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() + { + await this.TestUpsertWithModelAsync( + dataModel: new BsonVectorStoreWithNameTestModel { Id = "key", HotelName = "Test Name" }, + expectedPropertyName: "bson_hotel_name"); + } + + [Fact] + public async Task UpsertWithCustomMapperWorksCorrectlyAsync() + { + // Arrange + var hotel = new MongoDBHotelModel("key") { HotelName = "Test Name" }; + + var mockMapper = new Mock>(); + + mockMapper + .Setup(l => l.MapFromDataToStorageModel(It.IsAny())) + .Returns(new BsonDocument { ["_id"] = "key", ["my_name"] = "Test Name" }); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + new() { BsonDocumentCustomMapper = mockMapper.Object }); + + // Act + var result = await sut.UpsertAsync(hotel); + + // Assert + Assert.Equal("key", result); + + this._mockMongoCollection.Verify(l => l.ReplaceOneAsync( + It.IsAny>(), + It.Is(document => + document["_id"] == "key" && + document["my_name"] == "Test Name"), + It.IsAny(), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task GetWithCustomMapperWorksCorrectlyAsync() + { + // Arrange + const string RecordKey = "key"; + + var document = new BsonDocument { ["_id"] = RecordKey, ["my_name"] = "Test Name" }; + + var mockCursor = new Mock>(); + mockCursor + .Setup(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true); + + mockCursor + .Setup(l => l.Current) + .Returns([document]); + + this._mockMongoCollection + .Setup(l => l.FindAsync( + It.IsAny>(), + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var mockMapper = new Mock>(); + + mockMapper + .Setup(l => l.MapFromStorageToDataModel(It.IsAny(), It.IsAny())) + .Returns(new MongoDBHotelModel(RecordKey) { HotelName = "Name from mapper" }); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + new() { BsonDocumentCustomMapper = mockMapper.Object }); + + // Act + var result = await sut.GetAsync(RecordKey); + + // Assert + Assert.NotNull(result); + Assert.Equal(RecordKey, result.HotelId); + Assert.Equal("Name from mapper", result.HotelName); + } + + [Theory] + [MemberData(nameof(VectorizedSearchVectorTypeData))] + public async Task VectorizedSearchThrowsExceptionWithInvalidVectorTypeAsync(object vector, bool exceptionExpected) + { + // Arrange + this.MockCollectionForSearch(); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act & Assert + if (exceptionExpected) + { + await Assert.ThrowsAsync(async () => await sut.VectorizedSearchAsync(vector)); + } + else + { + var actual = await sut.VectorizedSearchAsync(vector); + + Assert.NotNull(actual); + } + } + + [Theory] + [InlineData(null, "TestEmbedding1", 1, 1)] + [InlineData("", "TestEmbedding1", 2, 2)] + [InlineData("TestEmbedding1", "TestEmbedding1", 3, 3)] + [InlineData("TestEmbedding2", "test_embedding_2", 4, 4)] + public async Task VectorizedSearchUsesValidQueryAsync( + string? vectorPropertyName, + string expectedVectorPropertyName, + int actualTop, + int expectedTop) + { + // Arrange + var vector = new ReadOnlyMemory([1f, 2f, 3f]); + + var expectedSearch = new BsonDocument + { + { "$vectorSearch", + new BsonDocument + { + { "index", "vector_index" }, + { "queryVector", BsonArray.Create(vector.ToArray()) }, + { "path", expectedVectorPropertyName }, + { "limit", expectedTop }, + { "numCandidates", expectedTop * 10 }, + } + } + }; + + var expectedProjection = new BsonDocument + { + { "$project", + new BsonDocument + { + { "similarityScore", new BsonDocument { { "$meta", "vectorSearchScore" } } }, + { "document", "$$ROOT" } + } + } + }; + + this.MockCollectionForSearch(); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var actual = await sut.VectorizedSearchAsync(vector, new() + { + VectorPropertyName = vectorPropertyName, + Top = actualTop, + }); + + // Assert + Assert.NotNull(await actual.Results.FirstOrDefaultAsync()); + + this._mockMongoCollection.Verify(l => l.AggregateAsync( + It.Is>(pipeline => + this.ComparePipeline(pipeline, expectedSearch, expectedProjection)), + It.IsAny(), + It.IsAny()), Times.Once()); + } + + [Fact] + public async Task VectorizedSearchThrowsExceptionWithNonExistentVectorPropertyNameAsync() + { + // Arrange + this.MockCollectionForSearch(); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + var options = new VectorSearchOptions { VectorPropertyName = "non-existent-property" }; + + // Act & Assert + await Assert.ThrowsAsync(async () => await (await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f]), options)).Results.FirstOrDefaultAsync()); + } + + [Fact] + public async Task VectorizedSearchReturnsRecordWithScoreAsync() + { + // Arrange + this.MockCollectionForSearch(); + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection"); + + // Act + var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([1f, 2f, 3f])); + + // Assert + var result = await actual.Results.FirstOrDefaultAsync(); + Assert.NotNull(result); + Assert.Equal("key", result.Record.HotelId); + Assert.Equal("Test Name", result.Record.HotelName); + Assert.Equal(0.99f, result.Score); + } + + public static TheoryData, string, bool> CollectionExistsData => new() + { + { ["collection-2"], "collection-2", true }, + { [], "non-existent-collection", false } + }; + + public static TheoryData, int> CreateCollectionIfNotExistsData => new() + { + { ["collection"], 0 }, + { [], 1 } + }; + + public static TheoryData VectorizedSearchVectorTypeData => new() + { + { new ReadOnlyMemory([1f, 2f, 3f]), false }, + { new ReadOnlyMemory([1f, 2f, 3f]), false }, + { new ReadOnlyMemory?(new([1f, 2f, 3f])), false }, + { new ReadOnlyMemory?(new([1f, 2f, 3f])), false }, + { new List([1f, 2f, 3f]), true }, + }; + + #region private + + private bool ComparePipeline( + PipelineDefinition actualPipeline, + BsonDocument expectedSearch, + BsonDocument expectedProjection) + { + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + + var documents = actualPipeline.Render(new RenderArgs(documentSerializer, serializerRegistry)).Documents; + + return + documents[0].ToJson() == expectedSearch.ToJson() && + documents[1].ToJson() == expectedProjection.ToJson(); + } + + private void MockCollectionForSearch() + { + var document = new BsonDocument { ["_id"] = "key", ["HotelName"] = "Test Name" }; + var searchResult = new BsonDocument { ["document"] = document, ["similarityScore"] = 0.99f }; + + var mockCursor = new Mock>(); + mockCursor + .Setup(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true); + + mockCursor + .Setup(l => l.Current) + .Returns([searchResult]); + + this._mockMongoCollection + .Setup(l => l.AggregateAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ReturnsAsync(mockCursor.Object); + } + + private async Task TestUpsertWithModelAsync( + TDataModel dataModel, + string expectedPropertyName, + VectorStoreRecordDefinition? definition = null) + where TDataModel : class + { + // Arrange + var serializerRegistry = BsonSerializer.SerializerRegistry; + var documentSerializer = serializerRegistry.GetSerializer(); + var expectedDefinition = Builders.Filter.Eq(document => document["_id"], "key"); + + MongoDBVectorStoreRecordCollectionOptions? options = definition != null ? + new() { VectorStoreRecordDefinition = definition } : + null; + + var sut = new MongoDBVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + options); + + // Act + var result = await sut.UpsertAsync(dataModel); + + // Assert + Assert.Equal("key", result); + + this._mockMongoCollection.Verify(l => l.ReplaceOneAsync( + It.Is>(definition => + CompareFilterDefinitions(definition, expectedDefinition, documentSerializer, serializerRegistry)), + It.Is(document => + document["_id"] == "key" && + document.Contains(expectedPropertyName) && + document[expectedPropertyName] == "Test Name"), + It.IsAny(), + It.IsAny()), Times.Once()); + } + + private static bool CompareFilterDefinitions( + FilterDefinition actual, + FilterDefinition expected, + IBsonSerializer documentSerializer, + IBsonSerializerRegistry serializerRegistry) + { + return actual.Render(new RenderArgs(documentSerializer, serializerRegistry)) == + expected.Render(new RenderArgs(documentSerializer, serializerRegistry)); + } + +#pragma warning disable CA1812 + private sealed class TestModel + { + public string? Id { get; set; } + + public string? HotelName { get; set; } + } + + private sealed class VectorStoreTestModel + { + [VectorStoreRecordKey] + public string? Id { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "hotel_name")] + public string? HotelName { get; set; } + } + + private sealed class BsonTestModel + { + [BsonId] + public string? Id { get; set; } + + [BsonElement("hotel_name")] + public string? HotelName { get; set; } + } + + private sealed class BsonVectorStoreTestModel + { + [BsonId] + [VectorStoreRecordKey] + public string? Id { get; set; } + + [BsonElement("hotel_name")] + [VectorStoreRecordData] + public string? HotelName { get; set; } + } + + private sealed class BsonVectorStoreWithNameTestModel + { + [BsonId] + [VectorStoreRecordKey] + public string? Id { get; set; } + + [BsonElement("bson_hotel_name")] + [VectorStoreRecordData(StoragePropertyName = "storage_hotel_name")] + public string? HotelName { get; set; } + } + + private sealed class VectorSearchModel + { + [BsonId] + [VectorStoreRecordKey] + public string? Id { get; set; } + + [VectorStoreRecordData] + public string? HotelName { get; set; } + + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat, StoragePropertyName = "test_embedding_1")] + public ReadOnlyMemory TestEmbedding1 { get; set; } + + [BsonElement("test_embedding_2")] + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineDistance, IndexKind: IndexKind.IvfFlat)] + public ReadOnlyMemory TestEmbedding2 { get; set; } + } +#pragma warning restore CA1812 + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordMapperTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs similarity index 81% rename from dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordMapperTests.cs rename to dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs index 01e5787bad6c..65ccefcc6eee 100644 --- a/dotnet/src/Connectors/Connectors.AzureCosmosDBMongoDB.UnitTests/AzureCosmosDBMongoDBVectorStoreRecordMapperTests.cs +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreRecordMapperTests.cs @@ -3,20 +3,20 @@ using System; using System.Collections.Generic; using Microsoft.Extensions.VectorData; -using Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +using Microsoft.SemanticKernel.Connectors.MongoDB; using MongoDB.Bson; using Xunit; -namespace SemanticKernel.Connectors.AzureCosmosDBMongoDB.UnitTests; +namespace SemanticKernel.Connectors.MongoDB.UnitTests; /// -/// Unit tests for class. +/// Unit tests for class. /// -public sealed class AzureCosmosDBMongoDBVectorStoreRecordMapperTests +public sealed class MongoDBVectorStoreRecordMapperTests { - private readonly AzureCosmosDBMongoDBVectorStoreRecordMapper _sut; + private readonly MongoDBVectorStoreRecordMapper _sut; - public AzureCosmosDBMongoDBVectorStoreRecordMapperTests() + public MongoDBVectorStoreRecordMapperTests() { var keyProperty = new VectorStoreRecordKeyProperty("HotelId", typeof(string)); @@ -32,14 +32,14 @@ public AzureCosmosDBMongoDBVectorStoreRecordMapperTests() ] }; - this._sut = new(new VectorStoreRecordPropertyReader(typeof(AzureCosmosDBMongoDBHotelModel), definition, null)); + this._sut = new(new VectorStoreRecordPropertyReader(typeof(MongoDBHotelModel), definition, null)); } [Fact] public void MapFromDataToStorageModelReturnsValidObject() { // Arrange - var hotel = new AzureCosmosDBMongoDBHotelModel("key") + var hotel = new MongoDBHotelModel("key") { HotelName = "Test Name", Tags = ["tag1", "tag2"], diff --git a/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreTests.cs b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreTests.cs new file mode 100644 index 000000000000..a6be91ac04cc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.MongoDB.UnitTests/MongoDBVectorStoreTests.cs @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Driver; +using Moq; +using Xunit; + +namespace SemanticKernel.Connectors.MongoDB.UnitTests; + +/// +/// Unit tests for class. +/// +public sealed class MongoDBVectorStoreTests +{ + private readonly Mock _mockMongoDatabase = new(); + + [Fact] + public void GetCollectionWithNotSupportedKeyThrowsException() + { + // Arrange + var sut = new MongoDBVectorStore(this._mockMongoDatabase.Object); + + // Act & Assert + Assert.Throws(() => sut.GetCollection("collection")); + } + + [Fact] + public void GetCollectionWithFactoryReturnsCustomCollection() + { + // Arrange + var mockFactory = new Mock(); + var mockRecordCollection = new Mock>(); + + mockFactory + .Setup(l => l.CreateVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + It.IsAny())) + .Returns(mockRecordCollection.Object); + + var sut = new MongoDBVectorStore( + this._mockMongoDatabase.Object, + new MongoDBVectorStoreOptions { VectorStoreCollectionFactory = mockFactory.Object }); + + // Act + var collection = sut.GetCollection("collection"); + + // Assert + Assert.Same(mockRecordCollection.Object, collection); + mockFactory.Verify(l => l.CreateVectorStoreRecordCollection( + this._mockMongoDatabase.Object, + "collection", + It.IsAny()), Times.Once()); + } + + [Fact] + public void GetCollectionWithoutFactoryReturnsDefaultCollection() + { + // Arrange + var sut = new MongoDBVectorStore(this._mockMongoDatabase.Object); + + // Act + var collection = sut.GetCollection("collection"); + + // Assert + Assert.NotNull(collection); + } + + [Fact] + public async Task ListCollectionNamesReturnsCollectionNamesAsync() + { + // Arrange + var expectedCollectionNames = new List { "collection-1", "collection-2", "collection-3" }; + + var mockCursor = new Mock>(); + mockCursor + .SetupSequence(l => l.MoveNextAsync(It.IsAny())) + .ReturnsAsync(true) + .ReturnsAsync(false); + + mockCursor + .Setup(l => l.Current) + .Returns(expectedCollectionNames); + + this._mockMongoDatabase + .Setup(l => l.ListCollectionNamesAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(mockCursor.Object); + + var sut = new MongoDBVectorStore(this._mockMongoDatabase.Object); + + // Act + var actualCollectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Equal(expectedCollectionNames, actualCollectionNames); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreCollectionFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreCollectionFixture.cs new file mode 100644 index 000000000000..7defbeec1f5c --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreCollectionFixture.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; + +[CollectionDefinition("MongoDBVectorStoreCollection")] +public class MongoDBVectorStoreCollectionFixture : ICollectionFixture +{ } diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs new file mode 100644 index 000000000000..6c037c70e11b --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreFixture.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Docker.DotNet; +using Docker.DotNet.Models; +using Microsoft.Extensions.VectorData; +using MongoDB.Driver; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; + +public class MongoDBVectorStoreFixture : IAsyncLifetime +{ + private readonly List _testCollections = ["sk-test-hotels", "sk-test-contacts", "sk-test-addresses"]; + + /// Main test collection for tests. + public string TestCollection => this._testCollections[0]; + + /// that can be used to manage the collections in MongoDB. + public IMongoDatabase MongoDatabase { get; } + + /// Gets the manually created vector store record definition for MongoDB test model. + public VectorStoreRecordDefinition HotelVectorStoreRecordDefinition { get; private set; } + + /// The id of the MongoDB container that we are testing with. + private string? _containerId = null; + + /// The Docker client we are using to create a MongoDB container with. + private readonly DockerClient _client; + + /// + /// Initializes a new instance of the class. + /// + public MongoDBVectorStoreFixture() + { + using var dockerClientConfiguration = new DockerClientConfiguration(); + this._client = dockerClientConfiguration.CreateClient(); + + var mongoClient = new MongoClient("mongodb://localhost:27017/?directConnection=true"); + + this.MongoDatabase = mongoClient.GetDatabase("test"); + + this.HotelVectorStoreRecordDefinition = new() + { + Properties = + [ + new VectorStoreRecordKeyProperty("HotelId", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)), + new VectorStoreRecordDataProperty("HotelCode", typeof(int)), + new VectorStoreRecordDataProperty("ParkingIncluded", typeof(bool)) { StoragePropertyName = "parking_is_included" }, + new VectorStoreRecordDataProperty("HotelRating", typeof(float)), + new VectorStoreRecordDataProperty("Tags", typeof(List)), + new VectorStoreRecordDataProperty("Timestamp", typeof(DateTime)), + new VectorStoreRecordDataProperty("Description", typeof(string)), + new VectorStoreRecordVectorProperty("DescriptionEmbedding", typeof(ReadOnlyMemory?)) { Dimensions = 4, IndexKind = IndexKind.IvfFlat, DistanceFunction = DistanceFunction.CosineDistance } + ] + }; + } + + public async Task InitializeAsync() + { + this._containerId = await SetupMongoDBContainerAsync(this._client); + + foreach (var collection in this._testCollections) + { + await this.MongoDatabase.CreateCollectionAsync(collection); + } + } + + public async Task DisposeAsync() + { + var cursor = await this.MongoDatabase.ListCollectionNamesAsync(); + + while (await cursor.MoveNextAsync().ConfigureAwait(false)) + { + foreach (var collection in cursor.Current) + { + await this.MongoDatabase.DropCollectionAsync(collection); + } + } + + if (this._containerId != null) + { + await this._client.Containers.StopContainerAsync(this._containerId, new ContainerStopParameters()); + await this._client.Containers.RemoveContainerAsync(this._containerId, new ContainerRemoveParameters()); + } + } + +#pragma warning disable CS8618 + public record MongoDBHotel() + { + /// The key of the record. + [VectorStoreRecordKey] + public string HotelId { get; init; } + + /// A string metadata field. + [VectorStoreRecordData(IsFilterable = true)] + public string? HotelName { get; set; } + + /// An int metadata field. + [VectorStoreRecordData] + public int HotelCode { get; set; } + + /// A float metadata field. + [VectorStoreRecordData] + public float? HotelRating { get; set; } + + /// A bool metadata field. + [VectorStoreRecordData(StoragePropertyName = "parking_is_included")] + public bool ParkingIncluded { get; set; } + + /// An array metadata field. + [VectorStoreRecordData] + public List Tags { get; set; } = []; + + /// A data field. + [VectorStoreRecordData] + public string Description { get; set; } + + /// A datetime metadata field. + [VectorStoreRecordData] + public DateTime Timestamp { get; set; } + + /// A vector field. + [VectorStoreRecordVector(Dimensions: 4, DistanceFunction: DistanceFunction.CosineSimilarity, IndexKind: IndexKind.IvfFlat)] + public ReadOnlyMemory? DescriptionEmbedding { get; set; } + } +#pragma warning restore CS8618 + + #region private + + private static async Task SetupMongoDBContainerAsync(DockerClient client) + { + const string Image = "mongodb/mongodb-atlas-local"; + const string Tag = "latest"; + + await client.Images.CreateImageAsync( + new ImagesCreateParameters + { + FromImage = Image, + Tag = Tag, + }, + null, + new Progress()); + + var container = await client.Containers.CreateContainerAsync(new CreateContainerParameters() + { + Image = $"{Image}:{Tag}", + HostConfig = new HostConfig() + { + PortBindings = new Dictionary> + { + { "27017", new List { new() { HostPort = "27017" } } }, + }, + PublishAllPorts = true + }, + ExposedPorts = new Dictionary + { + { "27017", default }, + }, + }); + + await client.Containers.StartContainerAsync( + container.ID, + new ContainerStartParameters()); + + return container.ID; + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs new file mode 100644 index 000000000000..b603448f1adc --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreRecordCollectionTests.cs @@ -0,0 +1,534 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using Xunit; +using static SemanticKernel.IntegrationTests.Connectors.MongoDB.MongoDBVectorStoreFixture; + +namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; + +[Collection("MongoDBVectorStoreCollection")] +public class MongoDBVectorStoreRecordCollectionTests(MongoDBVectorStoreFixture fixture) +{ + [Theory] + [InlineData("sk-test-hotels", true)] + [InlineData("nonexistentcollection", false)] + public async Task CollectionExistsReturnsCollectionStateAsync(string collectionName, bool expectedExists) + { + // Arrange + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + + // Act + var actual = await sut.CollectionExistsAsync(); + + // Assert + Assert.Equal(expectedExists, actual); + } + + [Fact] + public async Task ItCanCreateCollectionAsync() + { + // Arrange + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + // Act + await sut.CreateCollectionAsync(); + + // Assert + Assert.True(await sut.CollectionExistsAsync()); + } + + [Theory] + [InlineData(true, true)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(false, false)] + public async Task ItCanCreateCollectionUpsertAndGetAsync(bool includeVectors, bool useRecordDefinition) + { + // Arrange + const string HotelId = "55555555-5555-5555-5555-555555555555"; + + var collectionNamePostfix = useRecordDefinition ? "with-definition" : "with-type"; + var collectionName = $"collection-{collectionNamePostfix}"; + + var options = new MongoDBVectorStoreRecordCollectionOptions + { + VectorStoreRecordDefinition = useRecordDefinition ? fixture.HotelVectorStoreRecordDefinition : null + }; + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, collectionName); + + var record = this.CreateTestHotel(HotelId); + + // Act + await sut.CreateCollectionAsync(); + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId, new() { IncludeVectors = includeVectors }); + + // Assert + Assert.True(await sut.CollectionExistsAsync()); + await sut.DeleteCollectionAsync(); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + + Assert.Equal(record.HotelId, getResult.HotelId); + Assert.Equal(record.HotelName, getResult.HotelName); + Assert.Equal(record.HotelCode, getResult.HotelCode); + Assert.Equal(record.HotelRating, getResult.HotelRating); + Assert.Equal(record.ParkingIncluded, getResult.ParkingIncluded); + Assert.Equal(record.Tags.ToArray(), getResult.Tags.ToArray()); + Assert.Equal(record.Description, getResult.Description); + Assert.Equal(record.Timestamp.ToUniversalTime(), getResult.Timestamp.ToUniversalTime()); + + if (includeVectors) + { + Assert.NotNull(getResult.DescriptionEmbedding); + Assert.Equal(record.DescriptionEmbedding!.Value.ToArray(), getResult.DescriptionEmbedding.Value.ToArray()); + } + else + { + Assert.Null(getResult.DescriptionEmbedding); + } + } + + [Fact] + public async Task ItCanDeleteCollectionAsync() + { + // Arrange + const string TempCollectionName = "temp-test"; + await fixture.MongoDatabase.CreateCollectionAsync(TempCollectionName); + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, TempCollectionName); + + Assert.True(await sut.CollectionExistsAsync()); + + // Act + await sut.DeleteCollectionAsync(); + + // Assert + Assert.False(await sut.CollectionExistsAsync()); + } + + [Fact] + public async Task ItCanGetAndDeleteRecordAsync() + { + // Arrange + const string HotelId = "55555555-5555-5555-5555-555555555555"; + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + var record = this.CreateTestHotel(HotelId); + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + + // Act + await sut.DeleteAsync(HotelId); + + getResult = await sut.GetAsync(HotelId); + + // Assert + Assert.Null(getResult); + } + + [Fact] + public async Task ItCanGetAndDeleteBatchAsync() + { + // Arrange + const string HotelId1 = "11111111-1111-1111-1111-111111111111"; + const string HotelId2 = "22222222-2222-2222-2222-222222222222"; + const string HotelId3 = "33333333-3333-3333-3333-333333333333"; + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + var record1 = this.CreateTestHotel(HotelId1); + var record2 = this.CreateTestHotel(HotelId2); + var record3 = this.CreateTestHotel(HotelId3); + + var upsertResults = await sut.UpsertBatchAsync([record1, record2, record3]).ToListAsync(); + var getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + Assert.Equal([HotelId1, HotelId2, HotelId3], upsertResults); + + Assert.NotNull(getResults.First(l => l.HotelId == HotelId1)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId2)); + Assert.NotNull(getResults.First(l => l.HotelId == HotelId3)); + + // Act + await sut.DeleteBatchAsync([HotelId1, HotelId2, HotelId3]); + + getResults = await sut.GetBatchAsync([HotelId1, HotelId2, HotelId3]).ToListAsync(); + + // Assert + Assert.Empty(getResults); + } + + [Fact] + public async Task ItCanUpsertRecordAsync() + { + // Arrange + const string HotelId = "55555555-5555-5555-5555-555555555555"; + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + var record = this.CreateTestHotel(HotelId); + + var upsertResult = await sut.UpsertAsync(record); + var getResult = await sut.GetAsync(HotelId); + + Assert.Equal(HotelId, upsertResult); + Assert.NotNull(getResult); + + // Act + record.HotelName = "Updated name"; + record.HotelRating = 10; + + upsertResult = await sut.UpsertAsync(record); + getResult = await sut.GetAsync(HotelId); + + // Assert + Assert.NotNull(getResult); + Assert.Equal("Updated name", getResult.HotelName); + Assert.Equal(10, getResult.HotelRating); + } + + [Fact] + public async Task UpsertWithModelWorksCorrectlyAsync() + { + // Arrange + var definition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) + } + }; + + var model = new TestModel { Id = "key", HotelName = "Test Name" }; + + var sut = new MongoDBVectorStoreRecordCollection( + fixture.MongoDatabase, + fixture.TestCollection, + new() { VectorStoreRecordDefinition = definition }); + + // Act + var upsertResult = await sut.UpsertAsync(model); + var getResult = await sut.GetAsync(model.Id); + + // Assert + Assert.Equal("key", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal("key", getResult.Id); + Assert.Equal("Test Name", getResult.HotelName); + } + + [Fact] + public async Task UpsertWithVectorStoreModelWorksCorrectlyAsync() + { + // Arrange + var model = new VectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + // Act + var upsertResult = await sut.UpsertAsync(model); + var getResult = await sut.GetAsync(model.HotelId); + + // Assert + Assert.Equal("key", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal("key", getResult.HotelId); + Assert.Equal("Test Name", getResult.HotelName); + } + + [Fact] + public async Task UpsertWithBsonModelWorksCorrectlyAsync() + { + // Arrange + var definition = new VectorStoreRecordDefinition + { + Properties = new List + { + new VectorStoreRecordKeyProperty("Id", typeof(string)), + new VectorStoreRecordDataProperty("HotelName", typeof(string)) + } + }; + + var model = new BsonTestModel { Id = "key", HotelName = "Test Name" }; + + var sut = new MongoDBVectorStoreRecordCollection( + fixture.MongoDatabase, + fixture.TestCollection, + new() { VectorStoreRecordDefinition = definition }); + + // Act + var upsertResult = await sut.UpsertAsync(model); + var getResult = await sut.GetAsync(model.Id); + + // Assert + Assert.Equal("key", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal("key", getResult.Id); + Assert.Equal("Test Name", getResult.HotelName); + } + + [Fact] + public async Task UpsertWithBsonVectorStoreModelWorksCorrectlyAsync() + { + // Arrange + var model = new BsonVectorStoreTestModel { HotelId = "key", HotelName = "Test Name" }; + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + // Act + var upsertResult = await sut.UpsertAsync(model); + var getResult = await sut.GetAsync(model.HotelId); + + // Assert + Assert.Equal("key", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal("key", getResult.HotelId); + Assert.Equal("Test Name", getResult.HotelName); + } + + [Fact] + public async Task UpsertWithBsonVectorStoreWithNameModelWorksCorrectlyAsync() + { + // Arrange + var model = new BsonVectorStoreWithNameTestModel { Id = "key", HotelName = "Test Name" }; + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, fixture.TestCollection); + + // Act + var upsertResult = await sut.UpsertAsync(model); + var getResult = await sut.GetAsync(model.Id); + + // Assert + Assert.Equal("key", upsertResult); + + Assert.NotNull(getResult); + Assert.Equal("key", getResult.Id); + Assert.Equal("Test Name", getResult.HotelName); + } + + [Fact] + public async Task VectorizedSearchReturnsValidResultsByDefaultAsync() + { + // Arrange + var hotel1 = this.CreateTestHotel(hotelId: "key1", embedding: new[] { 30f, 31f, 32f, 33f }); + var hotel2 = this.CreateTestHotel(hotelId: "key2", embedding: new[] { 31f, 32f, 33f, 34f }); + var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); + var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearch"); + + await sut.CreateCollectionIfNotExistsAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f])); + + // Assert + var searchResults = await actual.Results.ToListAsync(); + var ids = searchResults.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal("key1", ids[0]); + Assert.Equal("key2", ids[1]); + Assert.Equal("key3", ids[2]); + + Assert.DoesNotContain("key4", ids); + + Assert.Equal(1, searchResults.First(l => l.Record.HotelId == "key1").Score); + } + + [Fact] + public async Task VectorizedSearchReturnsValidResultsWithOffsetAsync() + { + // Arrange + var hotel1 = this.CreateTestHotel(hotelId: "key1", embedding: new[] { 30f, 31f, 32f, 33f }); + var hotel2 = this.CreateTestHotel(hotelId: "key2", embedding: new[] { 31f, 32f, 33f, 34f }); + var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); + var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + + await sut.CreateCollectionIfNotExistsAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + { + Top = 2, + Skip = 2 + }); + + // Assert + var searchResults = await actual.Results.ToListAsync(); + var ids = searchResults.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal("key3", ids[0]); + Assert.Equal("key4", ids[1]); + + Assert.DoesNotContain("key1", ids); + Assert.DoesNotContain("key2", ids); + } + + [Fact] + public async Task VectorizedSearchReturnsValidResultsWithFilterAsync() + { + // Arrange + var hotel1 = this.CreateTestHotel(hotelId: "key1", embedding: new[] { 30f, 31f, 32f, 33f }); + var hotel2 = this.CreateTestHotel(hotelId: "key2", embedding: new[] { 31f, 32f, 33f, 34f }); + var hotel3 = this.CreateTestHotel(hotelId: "key3", embedding: new[] { 20f, 20f, 20f, 20f }); + var hotel4 = this.CreateTestHotel(hotelId: "key4", embedding: new[] { -1000f, -1000f, -1000f, -1000f }); + + var sut = new MongoDBVectorStoreRecordCollection(fixture.MongoDatabase, "TestVectorizedSearchWithOffset"); + + await sut.CreateCollectionIfNotExistsAsync(); + + await sut.UpsertBatchAsync([hotel4, hotel2, hotel3, hotel1]).ToListAsync(); + + // Act + var actual = await sut.VectorizedSearchAsync(new ReadOnlyMemory([30f, 31f, 32f, 33f]), new() + { + Filter = new VectorSearchFilter().EqualTo(nameof(MongoDBHotel.HotelName), "My Hotel key2") + }); + + // Assert + var searchResults = await actual.Results.ToListAsync(); + var ids = searchResults.Select(l => l.Record.HotelId).ToList(); + + Assert.Equal("key2", ids[0]); + + Assert.DoesNotContain("key1", ids); + Assert.DoesNotContain("key3", ids); + Assert.DoesNotContain("key4", ids); + } + + [Fact] + public async Task ItCanUpsertAndRetrieveUsingTheGenericMapperAsync() + { + // Arrange + var options = new MongoDBVectorStoreRecordCollectionOptions> + { + VectorStoreRecordDefinition = fixture.HotelVectorStoreRecordDefinition + }; + + var sut = new MongoDBVectorStoreRecordCollection>(fixture.MongoDatabase, fixture.TestCollection, options); + + // Act + var upsertResult = await sut.UpsertAsync(new VectorStoreGenericDataModel("GenericMapper-1") + { + Data = + { + { "HotelName", "Generic Mapper Hotel" }, + { "Description", "This is a generic mapper hotel" }, + { "Tags", new string[] { "generic" } }, + { "ParkingIncluded", false }, + { "Timestamp", new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime() }, + { "HotelRating", 3.6f } + }, + Vectors = + { + { "DescriptionEmbedding", new ReadOnlyMemory([30f, 31f, 32f, 33f]) } + } + }); + + var localGetResult = await sut.GetAsync("GenericMapper-1", new GetRecordOptions { IncludeVectors = true }); + + // Assert + Assert.NotNull(upsertResult); + Assert.Equal("GenericMapper-1", upsertResult); + + Assert.NotNull(localGetResult); + Assert.Equal("Generic Mapper Hotel", localGetResult.Data["HotelName"]); + Assert.Equal("This is a generic mapper hotel", localGetResult.Data["Description"]); + Assert.Equal(new[] { "generic" }, localGetResult.Data["Tags"]); + Assert.False((bool?)localGetResult.Data["ParkingIncluded"]); + Assert.Equal(new DateTime(1970, 1, 18, 0, 0, 0).ToUniversalTime(), localGetResult.Data["Timestamp"]); + Assert.Equal(3.6f, localGetResult.Data["HotelRating"]); + Assert.Equal(new[] { 30f, 31f, 32f, 33f }, ((ReadOnlyMemory)localGetResult.Vectors["DescriptionEmbedding"]!).ToArray()); + } + + #region private + + private MongoDBHotel CreateTestHotel(string hotelId, ReadOnlyMemory? embedding = null) + { + return new MongoDBHotel + { + HotelId = hotelId, + HotelName = $"My Hotel {hotelId}", + HotelCode = 42, + HotelRating = 4.5f, + ParkingIncluded = true, + Tags = { "t1", "t2" }, + Description = "This is a great hotel.", + Timestamp = new DateTime(2024, 09, 23, 15, 32, 33), + DescriptionEmbedding = embedding ?? new[] { 30f, 31f, 32f, 33f }, + }; + } + + private sealed class TestModel + { + public string? Id { get; set; } + + public string? HotelName { get; set; } + } + + private sealed class VectorStoreTestModel + { + [VectorStoreRecordKey] + public string? HotelId { get; set; } + + [VectorStoreRecordData(StoragePropertyName = "hotel_name")] + public string? HotelName { get; set; } + } + + private sealed class BsonTestModel + { + [BsonId] + public string? Id { get; set; } + + [BsonElement("hotel_name")] + public string? HotelName { get; set; } + } + + private sealed class BsonVectorStoreTestModel + { + [BsonId] + [VectorStoreRecordKey] + public string? HotelId { get; set; } + + [BsonElement("hotel_name")] + [VectorStoreRecordData] + public string? HotelName { get; set; } + } + + private sealed class BsonVectorStoreWithNameTestModel + { + [BsonId] + [VectorStoreRecordKey] + public string? Id { get; set; } + + [BsonElement("bson_hotel_name")] + [VectorStoreRecordData(StoragePropertyName = "storage_hotel_name")] + public string? HotelName { get; set; } + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs new file mode 100644 index 000000000000..8c1ffab4fd5b --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/MongoDB/MongoDBVectorStoreTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.MongoDB; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.MongoDB; + +[Collection("MongoDBVectorStoreCollection")] +public class MongoDBVectorStoreTests(MongoDBVectorStoreFixture fixture) +{ + [Fact] + public async Task ItCanGetAListOfExistingCollectionNamesAsync() + { + // Arrange + var sut = new MongoDBVectorStore(fixture.MongoDatabase); + + // Act + var collectionNames = await sut.ListCollectionNamesAsync().ToListAsync(); + + // Assert + Assert.Contains("sk-test-hotels", collectionNames); + Assert.Contains("sk-test-contacts", collectionNames); + Assert.Contains("sk-test-addresses", collectionNames); + } +} diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs similarity index 75% rename from dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs rename to dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs index 0b2ded038a89..5fdcbcd91389 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBConstants.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBConstants.cs @@ -2,22 +2,30 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using Microsoft.Extensions.VectorData; -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// -/// Constants for Azure CosmosDB MongoDB vector store implementation. +/// Constants for MongoDB vector store implementation. /// -internal static class AzureCosmosDBMongoDBConstants +[ExcludeFromCodeCoverage] +internal static class MongoDBConstants { + /// Default ratio of number of nearest neighbors to number of documents to return. + internal const int DefaultNumCandidatesRatio = 10; + + /// Default vector index name. + internal const string DefaultVectorIndexName = "vector_index"; + /// Default index kind for vector search. internal const string DefaultIndexKind = IndexKind.IvfFlat; /// Default distance function for vector search. internal const string DefaultDistanceFunction = DistanceFunction.CosineDistance; - /// Reserved key property name in Azure CosmosDB MongoDB. + /// Reserved key property name in MongoDB. internal const string MongoReservedKeyPropertyName = "_id"; /// Reserved key property name in data model. diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs similarity index 90% rename from dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs rename to dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs index adfecb696581..8ec0dffb935c 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBGenericDataModelMapper.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBGenericDataModelMapper.cs @@ -3,25 +3,27 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.Extensions.VectorData; using MongoDB.Bson; -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +namespace Microsoft.SemanticKernel.Connectors.MongoDB; /// -/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within Azure CosmosDB MongoDB. +/// A mapper that maps between the generic Semantic Kernel data model and the model that the data is stored under, within MongoDB. /// -internal sealed class AzureCosmosDBMongoDBGenericDataModelMapper : IVectorStoreRecordMapper, BsonDocument> +[ExcludeFromCodeCoverage] +internal sealed class MongoDBGenericDataModelMapper : IVectorStoreRecordMapper, BsonDocument> { /// A that defines the schema of the data in the database. private readonly VectorStoreRecordDefinition _vectorStoreRecordDefinition; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// A that defines the schema of the data in the database. - public AzureCosmosDBMongoDBGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition) + public MongoDBGenericDataModelMapper(VectorStoreRecordDefinition vectorStoreRecordDefinition) { Verify.NotNull(vectorStoreRecordDefinition); @@ -42,7 +44,7 @@ public BsonDocument MapFromDataToStorageModel(VectorStoreGenericDataModel MapFromStorageToDataModel(BsonDocumen if (property is VectorStoreRecordKeyProperty keyProperty) { - if (storageModel.TryGetValue(AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName, out var keyValue)) + if (storageModel.TryGetValue(MongoDBConstants.MongoReservedKeyPropertyName, out var keyValue)) { key = keyValue.AsString; } diff --git a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs similarity index 58% rename from dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs rename to dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs index d0eb96d9bcea..2ddb4f594fd7 100644 --- a/dotnet/src/Connectors/Connectors.Memory.AzureCosmosDBMongoDB/AzureCosmosDBMongoDBVectorStoreRecordMapper.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoDBVectorStoreRecordMapper.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics.CodeAnalysis; using System.Reflection; using Microsoft.Extensions.VectorData; using MongoDB.Bson; @@ -8,9 +9,10 @@ using MongoDB.Bson.Serialization.Attributes; using MongoDB.Bson.Serialization.Conventions; -namespace Microsoft.SemanticKernel.Connectors.AzureCosmosDBMongoDB; +namespace Microsoft.SemanticKernel.Connectors.MongoDB; -internal sealed class AzureCosmosDBMongoDBVectorStoreRecordMapper : IVectorStoreRecordMapper +[ExcludeFromCodeCoverage] +internal sealed class MongoDBVectorStoreRecordMapper : IVectorStoreRecordMapper { /// A key property info of the data model. private readonly PropertyInfo _keyProperty; @@ -19,14 +21,14 @@ internal sealed class AzureCosmosDBMongoDBVectorStoreRecordMapper : IVe private readonly string _keyPropertyName; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// A helper to access property information for the current data model and record definition. - public AzureCosmosDBMongoDBVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) + public MongoDBVectorStoreRecordMapper(VectorStoreRecordPropertyReader propertyReader) { - propertyReader.VerifyKeyProperties(AzureCosmosDBMongoDBConstants.SupportedKeyTypes); - propertyReader.VerifyDataProperties(AzureCosmosDBMongoDBConstants.SupportedDataTypes, supportEnumerable: true); - propertyReader.VerifyVectorProperties(AzureCosmosDBMongoDBConstants.SupportedVectorTypes); + propertyReader.VerifyKeyProperties(MongoDBConstants.SupportedKeyTypes); + propertyReader.VerifyDataProperties(MongoDBConstants.SupportedDataTypes, supportEnumerable: true); + propertyReader.VerifyVectorProperties(MongoDBConstants.SupportedVectorTypes); this._keyPropertyName = propertyReader.KeyPropertyName; this._keyProperty = propertyReader.KeyPropertyInfo; @@ -37,7 +39,7 @@ public AzureCosmosDBMongoDBVectorStoreRecordMapper(VectorStoreRecordPropertyRead }; ConventionRegistry.Register( - nameof(AzureCosmosDBMongoDBVectorStoreRecordMapper), + nameof(MongoDBVectorStoreRecordMapper), conventionPack, type => type == typeof(TRecord)); } @@ -47,13 +49,13 @@ public BsonDocument MapFromDataToStorageModel(TRecord dataModel) var document = dataModel.ToBsonDocument(); // Handle key property mapping due to reserved key name in Mongo. - if (!document.Contains(AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName)) + if (!document.Contains(MongoDBConstants.MongoReservedKeyPropertyName)) { var value = document[this._keyPropertyName]; document.Remove(this._keyPropertyName); - document[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName] = value; + document[MongoDBConstants.MongoReservedKeyPropertyName] = value; } return document; @@ -62,12 +64,12 @@ public BsonDocument MapFromDataToStorageModel(TRecord dataModel) public TRecord MapFromStorageToDataModel(BsonDocument storageModel, StorageToDataModelMapperOptions options) { // Handle key property mapping due to reserved key name in Mongo. - if (!this._keyPropertyName.Equals(AzureCosmosDBMongoDBConstants.DataModelReservedKeyPropertyName, StringComparison.OrdinalIgnoreCase) && + if (!this._keyPropertyName.Equals(MongoDBConstants.DataModelReservedKeyPropertyName, StringComparison.OrdinalIgnoreCase) && this._keyProperty.GetCustomAttribute() is null) { - var value = storageModel[AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName]; + var value = storageModel[MongoDBConstants.MongoReservedKeyPropertyName]; - storageModel.Remove(AzureCosmosDBMongoDBConstants.MongoReservedKeyPropertyName); + storageModel.Remove(MongoDBConstants.MongoReservedKeyPropertyName); storageModel[this._keyPropertyName] = value; }