diff --git a/go.sum b/go.sum index 996b4cf..45c0b05 100644 --- a/go.sum +++ b/go.sum @@ -30,8 +30,6 @@ github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4 github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk= github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= -github.com/ansys/aali-sharedtypes v1.0.4-0.20250912172539-2fcf3e45b5ae h1:2Euh3dbT5zULzOMId2fwvbZKcaBNCJhQzgLs9k3Fqag= -github.com/ansys/aali-sharedtypes v1.0.4-0.20250912172539-2fcf3e45b5ae/go.mod h1:Ze0xVXbyl63d/dN95UKHJjoGs7ZmF5OrykBtKQxgO1U= github.com/ansys/aali-sharedtypes v1.0.4-0.20250924070211-69c801e57963 h1:06/rH2IPd92eNvERw3TIEWGT4i/5F214/qLKUjZpzx4= github.com/ansys/aali-sharedtypes v1.0.4-0.20250924070211-69c801e57963/go.mod h1:Ze0xVXbyl63d/dN95UKHJjoGs7ZmF5OrykBtKQxgO1U= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= diff --git a/pkg/externalfunctions/llmhandler.go b/pkg/externalfunctions/llmhandler.go index dea9418..684dd58 100644 --- a/pkg/externalfunctions/llmhandler.go +++ b/pkg/externalfunctions/llmhandler.go @@ -1088,13 +1088,12 @@ func BuildFinalQueryForGeneralLLMRequest(request string, knowledgedbResponse []s } // Build the final query using the KnowledgeDB response and the original request + // Append all non-empty fields to provide maximum context from comprehensive DbResponse finalQuery = "Based on the following examples:\n\n--- INFO START ---\n" - for _, example := range knowledgedbResponse { - finalQuery += example.Text + "\n" + for i, example := range knowledgedbResponse { + finalQuery += dbResponsePromptFormat(example, i+1) + "\n\n" } finalQuery += "--- INFO END ---\n\n" + request + "\n" - - // Return the final query return finalQuery } @@ -1112,45 +1111,13 @@ func BuildFinalQueryForGeneralLLMRequest(request string, knowledgedbResponse []s // Returns: // - finalQuery: the final query func BuildFinalQueryForCodeLLMRequest(request string, knowledgedbResponse []sharedtypes.DbResponse) (finalQuery string) { - // Build the final query using the KnowledgeDB response and the original request - // We have to use the text from the DB response and the original request. - // - // The prompt should be in the following format: - // - // ****************************************************************************** - // Based on the following examples: - // - // --- START EXAMPLE {response_n}--- - // >>> Summary: - // {knowledge_db_response_n_summary} - // - // >>> Code snippet: - // ```python - // {knowledge_db_response_n_text} - // ``` - // --- END EXAMPLE {response_n}--- - // - // --- START EXAMPLE {response_n}--- - // ... - // --- END EXAMPLE {response_n}--- - // - // Generate the Python code for the following request: - // - // >>> Request: - // {original_request} - // ****************************************************************************** - // If there is no response from the KnowledgeDB, return the original request if len(knowledgedbResponse) > 0 { // Initial request finalQuery = "Based on the following examples:\n\n" for i, element := range knowledgedbResponse { - // Add the example number - finalQuery += "--- START EXAMPLE " + fmt.Sprint(i+1) + "---\n" - finalQuery += ">>> Summary:\n" + element.Summary + "\n\n" - finalQuery += ">>> Code snippet:\n```python\n" + element.Text + "\n```\n" - finalQuery += "--- END EXAMPLE " + fmt.Sprint(i+1) + "---\n\n" + finalQuery += dbResponsePromptFormat(element, i+1) + "\n\n" } } @@ -1161,6 +1128,123 @@ func BuildFinalQueryForCodeLLMRequest(request string, knowledgedbResponse []shar return finalQuery } +type CollectionType uint8 + +const ( + Unknown CollectionType = iota + ApiElement + UserGuide + Example +) + +func (coll CollectionType) String() string { + switch coll { + case ApiElement: + return "API ELEMENT" + case Example: + return "EXAMPLE" + case UserGuide: + return "USER GUIDE" + default: + return "UNKNOWN" + } +} + +// Format a DbResponse as a string to include in the context. +// +// This takes a best guess at what type of document the DbResponse represents (API element/user guide/example) +// and then formats it accordingly. +func dbResponsePromptFormat(dbresponse sharedtypes.DbResponse, num int) string { + var contentParts []string + + // Determine collection type and add appropriate header + collectionType := Unknown + if dbresponse.Type != "" && (dbresponse.NameFormatted != "" || dbresponse.Name != "") { + collectionType = ApiElement + } else if dbresponse.Title != "" && dbresponse.SectionName != "" { + collectionType = UserGuide + } else if dbresponse.DocumentName != "" { + collectionType = Example + } else { + logging.Log.Warnf(&logging.ContextMap{}, "was unable to determine format of DB response for prompt formatting %v", dbresponse) + } + + contentParts = append(contentParts, fmt.Sprintf("=== START %s #%d ===", collectionType, num)) + + // CodeGenerationElement fields (API/Element collections) + if collectionType == ApiElement { + if dbresponse.NameFormatted != "" { + contentParts = append(contentParts, "API: "+dbresponse.NameFormatted) + } + if dbresponse.NamePseudocode != "" { + contentParts = append(contentParts, "Function: "+dbresponse.NamePseudocode) + } + if dbresponse.Name != "" { + contentParts = append(contentParts, "Full Name: "+dbresponse.Name) + } + if dbresponse.Type != "" { + contentParts = append(contentParts, "Type: "+dbresponse.Type) + } + if dbresponse.ParentClass != "" { + contentParts = append(contentParts, "Parent: "+dbresponse.ParentClass) + } + } + + // VectorDatabaseUserGuideSection fields (User Guide collections) + if collectionType == UserGuide { + contentParts = append(contentParts, "Guide Title: "+dbresponse.Title) + contentParts = append(contentParts, "Section: "+dbresponse.SectionName) + if dbresponse.ParentSectionName != "" { + contentParts = append(contentParts, "Parent Section: "+dbresponse.ParentSectionName) + } + } + + // VectorDatabaseExample fields (Example collections) + if collectionType == Example { + contentParts = append(contentParts, "Example File: "+dbresponse.DocumentName) + // Convert []interface{} to []string for joining + var deps []string + for _, dep := range dbresponse.Dependencies { + if depStr, ok := dep.(string); ok { + deps = append(deps, depStr) + } + } + if len(deps) > 0 { + contentParts = append(contentParts, "Uses APIs: "+strings.Join(deps, ", ")) + } + } + + // Common fields + if dbresponse.DocumentName != "" { + contentParts = append(contentParts, "Document: "+dbresponse.DocumentName) + } + if dbresponse.Summary != "" { + contentParts = append(contentParts, "Summary: "+dbresponse.Summary) + } + if len(dbresponse.Keywords) > 0 { + contentParts = append(contentParts, "Keywords: "+strings.Join(dbresponse.Keywords, ", ")) + } + if len(dbresponse.Tags) > 0 { + contentParts = append(contentParts, "Tags: "+strings.Join(dbresponse.Tags, ", ")) + } + + // Handle Text field with proper formatting based on collection type + if dbresponse.Text != "" { + if collectionType == Example { + contentParts = append(contentParts, "Code:") + contentParts = append(contentParts, "```python") + contentParts = append(contentParts, dbresponse.Text) + contentParts = append(contentParts, "```") + } else { + contentParts = append(contentParts, "Content:") + contentParts = append(contentParts, dbresponse.Text) + } + } + + contentParts = append(contentParts, fmt.Sprintf("=== END %s #%d ===", collectionType, num)) + return strings.Join(contentParts, "\n") +} + type AppendMessageHistoryRole string const ( diff --git a/pkg/externalfunctions/llmhandler_test.go b/pkg/externalfunctions/llmhandler_test.go new file mode 100644 index 0000000..ae05ceb --- /dev/null +++ b/pkg/externalfunctions/llmhandler_test.go @@ -0,0 +1,122 @@ +// Copyright (C) 2025 ANSYS, Inc. and/or its affiliates. +// SPDX-License-Identifier: MIT +// +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package externalfunctions + +import ( + "strings" + "testing" + + "github.com/ansys/aali-sharedtypes/pkg/sharedtypes" + "github.com/stretchr/testify/assert" +) + +func TestDbResponsePromptFormat(t *testing.T) { + testCases := []struct { + name string + dbResponse sharedtypes.DbResponse + expectedLines []string + }{ + { + "API Element", + sharedtypes.DbResponse{ + Name: "M:Namespace.Class.MyCusomMethod(System.String,System.Int32)", + NamePseudocode: "MyCustomMethod", + NameFormatted: "My Custom Method", + Type: "Method", + ParentClass: "Namespace.Class", + Metadata: map[string]any{}, + }, + []string{ + "=== START API ELEMENT #1 ===", + "API: My Custom Method", + "Function: MyCustomMethod", + "Full Name: M:Namespace.Class.MyCusomMethod(System.String,System.Int32)", + "Type: Method", + "Parent: Namespace.Class", + "=== END API ELEMENT #1 ===", + }, + }, + { + "Example", + sharedtypes.DbResponse{ + DocumentName: "examples/my_example.py", + Text: "import random\n\ndef main():\n print(random.randint(0, 10))\n\nif __name__ == '__main__':\n main()\n", + PreviousChunk: "previous-chunk-id", + NextChunk: "next-chunk-id", + Dependencies: []any{"random"}, + DependencyEquivalences: map[string]any{"random": "random-equiv"}, + }, + []string{ + "=== START EXAMPLE #1 ===", + "Example File: examples/my_example.py", + "Uses APIs: random", + "Document: examples/my_example.py", + "Code:", + "```python", + "import random", + "", + "def main():", + " print(random.randint(0, 10))", + "", + "if __name__ == '__main__':", + " main()", + "", + "```", + "=== END EXAMPLE #1 ===", + }, + }, + { + "User Guide", + sharedtypes.DbResponse{ + SectionName: "Section", + DocumentName: "user_guide", + Title: "Title", + ParentSectionName: "Parent", + Level: "2", + Text: "Here is the user\nguide content", + PreviousChunk: "prev-chunk", + NextChunk: "next-chunk", + }, + []string{ + "=== START USER GUIDE #1 ===", + "Guide Title: Title", + "Section: Section", + "Parent Section: Parent", + "Document: user_guide", + "Content:", + "Here is the user", + "guide content", + "=== END USER GUIDE #1 ===", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := dbResponsePromptFormat(tc.dbResponse, 1) + expected := strings.Join(tc.expectedLines, "\n") + assert.Equal(t, expected, result) + }) + } + +}