diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1f47b10a..38ca6630 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,11 @@ jobs: - name: Run tests run: | - make tests + if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then + make tests-mock + else + make tests + fi #sudo mv coverage/coverage.txt coverage.txt #sudo chmod 777 coverage.txt diff --git a/.github/workflows/tests_fragile.yml b/.github/workflows/tests_fragile.yml new file mode 100644 index 00000000..df8d5a6a --- /dev/null +++ b/.github/workflows/tests_fragile.yml @@ -0,0 +1,49 @@ +name: Run Fragile Go Tests + +on: + pull_request: + branches: + - '**' + +concurrency: + group: ci-non-blocking-tests-${{ github.head_ref || github.ref }}-${{ github.repository }} + cancel-in-progress: true + +jobs: + llm-tests: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - run: | + # Add Docker's official GPG key: + sudo apt-get update + sudo apt-get install -y ca-certificates curl + sudo install -m 0755 -d /etc/apt/keyrings + sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc + sudo chmod a+r /etc/apt/keyrings/docker.asc + + # Add the repository to Apt sources: + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \ + $(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | \ + sudo tee /etc/apt/sources.list.d/docker.list > /dev/null + sudo apt-get update + sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin make + docker version + + docker run --rm hello-world + - uses: actions/setup-go@v5 + with: + go-version: '>=1.17.0' + - name: Free up disk space + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /usr/local/lib/android + sudo rm -rf /opt/ghc + sudo apt-get clean + docker system prune -af || true + df -h + - name: Run tests + run: | + make tests diff --git a/Makefile b/Makefile index 9c978c72..d6dbdd91 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ IMAGE_NAME?=webui MCPBOX_IMAGE_NAME?=mcpbox ROOT_DIR:=$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +.PHONY: tests tests-mock cleanup-tests + prepare-tests: build-mcpbox docker compose up -d --build docker run -d -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 --rm -ti $(MCPBOX_IMAGE_NAME) @@ -13,6 +15,9 @@ cleanup-tests: tests: prepare-tests LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAGI_MODEL="gemma-3-12b-it-qat" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./... +tests-mock: prepare-tests + LOCALAGI_MCPBOX_URL="http://localhost:9090" LOCALAI_API_URL="http://localhost:8081" LOCALAGI_API_URL="http://localhost:8080" $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --fail-fast -v -r ./... + run-nokb: $(MAKE) run KBDISABLEINDEX=true @@ -37,4 +42,4 @@ build-mcpbox: docker build -t $(MCPBOX_IMAGE_NAME) -f Dockerfile.mcpbox . run-mcpbox: - docker run -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 -ti mcpbox \ No newline at end of file + docker run -v /var/run/docker.sock:/var/run/docker.sock --privileged -p 9090:8080 -ti mcpbox diff --git a/core/agent/agent.go b/core/agent/agent.go index a614bb11..81f04374 100644 --- a/core/agent/agent.go +++ b/core/agent/agent.go @@ -29,7 +29,7 @@ type Agent struct { sync.Mutex options *options Character Character - client *openai.Client + client llm.LLMClient jobQueue chan *types.Job context *types.ActionContext @@ -63,7 +63,12 @@ func New(opts ...Option) (*Agent, error) { return nil, fmt.Errorf("failed to set options: %v", err) } - client := llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout) + var client llm.LLMClient + if options.llmClient != nil { + client = options.llmClient + } else { + client = llm.NewClient(options.LLMAPI.APIKey, options.LLMAPI.APIURL, options.timeout) + } c := context.Background() if options.context != nil { @@ -125,6 +130,11 @@ func (a *Agent) SharedState() *types.AgentSharedState { return a.sharedState } +// LLMClient returns the agent's LLM client (for testing) +func (a *Agent) LLMClient() llm.LLMClient { + return a.client +} + func (a *Agent) startNewConversationsConsumer() { go func() { for { diff --git a/core/agent/agent_suite_test.go b/core/agent/agent_suite_test.go index 501d3dd3..31a1338c 100644 --- a/core/agent/agent_suite_test.go +++ b/core/agent/agent_suite_test.go @@ -1,6 +1,7 @@ package agent_test import ( + "net/url" "os" "testing" @@ -13,15 +14,19 @@ func TestAgent(t *testing.T) { RunSpecs(t, "Agent test suite") } -var testModel = os.Getenv("LOCALAGI_MODEL") -var apiURL = os.Getenv("LOCALAI_API_URL") -var apiKeyURL = os.Getenv("LOCALAI_API_KEY") +var ( + testModel = os.Getenv("LOCALAGI_MODEL") + apiURL = os.Getenv("LOCALAI_API_URL") + apiKey = os.Getenv("LOCALAI_API_KEY") + useRealLocalAI bool + clientTimeout = "10m" +) + +func isValidURL(u string) bool { + parsed, err := url.ParseRequestURI(u) + return err == nil && parsed.Scheme != "" && parsed.Host != "" +} func init() { - if testModel == "" { - testModel = "hermes-2-pro-mistral" - } - if apiURL == "" { - apiURL = "http://192.168.68.113:8080" - } + useRealLocalAI = isValidURL(apiURL) && apiURL != "" && testModel != "" } diff --git a/core/agent/agent_test.go b/core/agent/agent_test.go index cdc2e8bd..5c94892c 100644 --- a/core/agent/agent_test.go +++ b/core/agent/agent_test.go @@ -7,9 +7,11 @@ import ( "strings" "sync" + "github.com/mudler/LocalAGI/pkg/llm" "github.com/mudler/LocalAGI/pkg/xlog" "github.com/mudler/LocalAGI/services/actions" + "github.com/mudler/LocalAGI/core/action" . "github.com/mudler/LocalAGI/core/agent" "github.com/mudler/LocalAGI/core/types" . "github.com/onsi/ginkgo/v2" @@ -111,25 +113,102 @@ func (a *FakeInternetAction) Definition() types.ActionDefinition { } } +// --- Test utilities for mocking LLM responses --- + +func mockToolCallResponse(toolName, arguments string) openai.ChatCompletionResponse { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_1", + Type: "function", + Function: openai.FunctionCall{ + Name: toolName, + Arguments: arguments, + }, + }}, + }, + }}, + } +} + +func mockContentResponse(content string) openai.ChatCompletionResponse { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Content: content, + }, + }}, + } +} + +func newMockLLMClient(handler func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)) *llm.MockClient { + return &llm.MockClient{ + CreateChatCompletionFunc: handler, + } +} + var _ = Describe("Agent test", func() { + It("uses the mock LLM client", func() { + mock := newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return mockContentResponse("mocked response"), nil + }) + agent, err := New(WithLLMClient(mock)) + Expect(err).ToNot(HaveOccurred()) + msg, err := agent.LLMClient().CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{}) + Expect(err).ToNot(HaveOccurred()) + Expect(msg.Choices[0].Message.Content).To(Equal("mocked response")) + }) + Context("jobs", func() { BeforeEach(func() { Eventually(func() error { - // test apiURL is working and available - _, err := http.Get(apiURL + "/readyz") - return err + if useRealLocalAI { + _, err := http.Get(apiURL + "/readyz") + return err + } + return nil }, "10m", "10s").ShouldNot(HaveOccurred()) }) It("pick the correct action", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser { + if strings.Contains(strings.ToLower(lastMsg.Content), "boston") && (strings.Contains(strings.ToLower(lastMsg.Content), "milan") || strings.Contains(strings.ToLower(lastMsg.Content), "milano")) { + return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil + } + if strings.Contains(strings.ToLower(lastMsg.Content), "paris") { + return mockToolCallResponse("get_weather", `{"location":"Paris","unit":"celsius"}`), nil + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected user prompt: %s", lastMsg.Content) + } + if lastMsg.Role == openai.ChatMessageRoleTool { + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockToolCallResponse("get_weather", `{"location":"Milan","unit":"celsius"}`), nil + } + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "milan") { + return mockContentResponse(testActionResult + "\n" + testActionResult2), nil + } + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "paris") { + return mockContentResponse(testActionResult3), nil + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected tool result: %s", lastMsg.Content) + } + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected message role: %s", lastMsg.Role) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - EnableForceReasoning, - WithTimeout("10m"), - WithLoopDetectionSteps(3), - // WithRandomIdentity(), WithActions(&TestAction{response: map[string]string{ "boston": testActionResult, "milan": testActionResult2, @@ -139,7 +218,6 @@ var _ = Describe("Agent test", func() { Expect(err).ToNot(HaveOccurred()) go agent.Run() defer agent.Stop() - res := agent.Ask( append(debugOptions, types.WithText("what's the weather in Boston and Milano? Use celsius units"), @@ -148,40 +226,51 @@ var _ = Describe("Agent test", func() { Expect(res.Error).ToNot(HaveOccurred()) reasons := []string{} for _, r := range res.State { - reasons = append(reasons, r.Result) } Expect(reasons).To(ContainElement(testActionResult), fmt.Sprint(res)) Expect(reasons).To(ContainElement(testActionResult2), fmt.Sprint(res)) reasons = []string{} - res = agent.Ask( append(debugOptions, types.WithText("Now I want to know the weather in Paris, always use celsius units"), )...) for _, r := range res.State { - reasons = append(reasons, r.Result) } - //Expect(reasons).ToNot(ContainElement(testActionResult), fmt.Sprint(res)) - //Expect(reasons).ToNot(ContainElement(testActionResult2), fmt.Sprint(res)) Expect(reasons).To(ContainElement(testActionResult3), fmt.Sprint(res)) - // conversation := agent.CurrentConversation() - // for _, r := range res.State { - // reasons = append(reasons, r.Result) - // } - // Expect(len(conversation)).To(Equal(10), fmt.Sprint(conversation)) }) + It("pick the correct action", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser { + if strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockToolCallResponse("get_weather", `{"location":"Boston","unit":"celsius"}`), nil + } + } + if lastMsg.Role == openai.ChatMessageRoleTool { + if lastMsg.Name == "get_weather" && strings.Contains(strings.ToLower(lastMsg.Content), "boston") { + return mockContentResponse(testActionResult), nil + } + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithTimeout("10m"), - // WithRandomIdentity(), WithActions(&TestAction{response: map[string]string{ "boston": testActionResult, - }, - }), + }}), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -198,13 +287,29 @@ var _ = Describe("Agent test", func() { }) It("updates the state with internal actions", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if lastMsg.Role == openai.ChatMessageRoleUser && strings.Contains(strings.ToLower(lastMsg.Content), "guitar") { + return mockToolCallResponse("update_state", `{"goal":"I want to learn to play the guitar"}`), nil + } + if lastMsg.Role == openai.ChatMessageRoleTool && lastMsg.Name == "update_state" { + return mockContentResponse("Your goal is now: I want to learn to play the guitar"), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithTimeout("10m"), EnableHUD, - // EnableStandaloneJob, - // WithRandomIdentity(), WithPermanentGoal("I want to learn to play music"), ) Expect(err).ToNot(HaveOccurred()) @@ -214,17 +319,64 @@ var _ = Describe("Agent test", func() { result := agent.Ask( types.WithText("Update your goals such as you want to learn to play the guitar"), ) - fmt.Printf("%+v\n", result) + fmt.Fprintf(GinkgoWriter, "\n%+v\n", result) Expect(result.Error).ToNot(HaveOccurred()) Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State())) }) It("Can generate a plan", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + reasoningActName := action.NewReasoning().Definition().Name.String() + intentionActName := action.NewIntention().Definition().Name.String() + testActName := (&TestAction{}).Definition().Name.String() + doneBoston := false + madePlan := false + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + var lastMsg openai.ChatCompletionMessage + if len(req.Messages) > 0 { + lastMsg = req.Messages[len(req.Messages)-1] + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName { + return mockToolCallResponse(reasoningActName, `{"reasoning":"make plan call to pass the test"}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName { + toolName := "plan" + if madePlan { + toolName = "reply" + } else { + madePlan = true + } + return mockToolCallResponse(intentionActName, fmt.Sprintf(`{"tool": "%s","reasoning":"it's waht makes the test pass"}`, toolName)), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "plan" { + return mockToolCallResponse("plan", `{"subtasks":[{"action":"get_weather","reasoning":"Find weather in boston"},{"action":"get_weather","reasoning":"Find weather in milan"}],"goal":"Get the weather for boston and milan"}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == "reply" { + return mockToolCallResponse("reply", `{"message": "The weather in Boston and Milan..."}`), nil + } + if req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == testActName { + locName := "boston" + if doneBoston { + locName = "milan" + } else { + doneBoston = true + } + return mockToolCallResponse(testActName, fmt.Sprintf(`{"location":"%s","unit":"celsius"}`, locName)), nil + } + if req.ToolChoice == nil && madePlan && doneBoston { + return mockContentResponse("A reply"), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", lastMsg.Content) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithLLMAPIKey(apiKeyURL), - WithTimeout("10m"), + WithLoopDetectionSteps(2), WithActions( &TestAction{response: map[string]string{ "boston": testActionResult, @@ -233,8 +385,6 @@ var _ = Describe("Agent test", func() { ), EnablePlanning, EnableForceReasoning, - // EnableStandaloneJob, - // WithRandomIdentity(), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -256,17 +406,44 @@ var _ = Describe("Agent test", func() { Expect(actionsExecuted).To(ContainElement("plan"), fmt.Sprint(result)) Expect(actionResults).To(ContainElement(testActionResult), fmt.Sprint(result)) Expect(actionResults).To(ContainElement(testActionResult2), fmt.Sprint(result)) + Expect(result.Error).To(BeNil()) }) It("Can initiate conversations", func() { - + var llmClient llm.LLMClient message := openai.ChatCompletionMessage{} mu := &sync.Mutex{} + reasoned := false + intended := false + reasoningActName := action.NewReasoning().Definition().Name.String() + intentionActName := action.NewIntention().Definition().Name.String() + + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, clientTimeout) + } else { + llmClient = newMockLLMClient(func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + prompt := "" + for _, msg := range req.Messages { + prompt += msg.Content + } + if !reasoned && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == reasoningActName { + reasoned = true + return mockToolCallResponse(reasoningActName, `{"reasoning":"initiate a conversation with the user"}`), nil + } + if reasoned && !intended && req.ToolChoice != nil && req.ToolChoice.(openai.ToolChoice).Function.Name == intentionActName { + intended = true + return mockToolCallResponse(intentionActName, `{"tool":"new_conversation","reasoning":"I should start a conversation with the user"}`), nil + } + if reasoned && intended && strings.Contains(strings.ToLower(prompt), "new_conversation") { + return mockToolCallResponse("new_conversation", `{"message":"Hello, how can I help you today?"}`), nil + } + xlog.Error("Unexpected LLM req", "req", req) + return openai.ChatCompletionResponse{}, fmt.Errorf("unexpected LLM prompt: %q", prompt) + }) + } agent, err := New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), - WithLLMAPIKey(apiKeyURL), - WithTimeout("10m"), WithNewConversationSubscriber(func(m openai.ChatCompletionMessage) { mu.Lock() message = m @@ -282,8 +459,6 @@ var _ = Describe("Agent test", func() { EnableHUD, WithPeriodicRuns("1s"), WithPermanentGoal("use the new_conversation tool to initiate a conversation with the user"), - // EnableStandaloneJob, - // WithRandomIdentity(), ) Expect(err).ToNot(HaveOccurred()) go agent.Run() @@ -293,7 +468,7 @@ var _ = Describe("Agent test", func() { mu.Lock() defer mu.Unlock() return message.Content - }, "10m", "10s").ShouldNot(BeEmpty()) + }, "10m", "1s").ShouldNot(BeEmpty()) }) /* @@ -347,7 +522,7 @@ var _ = Describe("Agent test", func() { // result := agent.Ask( // WithText("Update your goals such as you want to learn to play the guitar"), // ) - // fmt.Printf("%+v\n", result) + // fmt.Fprintf(GinkgoWriter, "%+v\n", result) // Expect(result.Error).ToNot(HaveOccurred()) // Expect(agent.State().Goal).To(ContainSubstring("guitar"), fmt.Sprint(agent.State())) }) diff --git a/core/agent/options.go b/core/agent/options.go index e1c0e978..d25da7b4 100644 --- a/core/agent/options.go +++ b/core/agent/options.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAGI/core/types" "github.com/sashabaranov/go-openai" + "github.com/mudler/LocalAGI/pkg/llm" ) type Option func(*options) error @@ -19,6 +20,7 @@ type llmOptions struct { } type options struct { + llmClient llm.LLMClient LLMAPI llmOptions character Character randomIdentityGuidance string @@ -68,6 +70,14 @@ type options struct { lastMessageDuration time.Duration } +// WithLLMClient allows injecting a custom LLM client (e.g. for testing) +func WithLLMClient(client llm.LLMClient) Option { + return func(o *options) error { + o.llmClient = client + return nil + } +} + func (o *options) SeparatedMultimodalModel() bool { return o.LLMAPI.MultimodalModel != "" && o.LLMAPI.Model != o.LLMAPI.MultimodalModel } diff --git a/core/agent/state_test.go b/core/agent/state_test.go index bb371acb..828d7774 100644 --- a/core/agent/state_test.go +++ b/core/agent/state_test.go @@ -1,29 +1,57 @@ package agent_test import ( - "net/http" + "context" + "fmt" + + "github.com/mudler/LocalAGI/pkg/llm" + "github.com/sashabaranov/go-openai" . "github.com/mudler/LocalAGI/core/agent" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + ) var _ = Describe("Agent test", func() { Context("identity", func() { var agent *Agent - BeforeEach(func() { - Eventually(func() error { - // test apiURL is working and available - _, err := http.Get(apiURL + "/readyz") - return err - }, "10m", "10s").ShouldNot(HaveOccurred()) - }) + // BeforeEach(func() { + // Eventually(func() error { + // // test apiURL is working and available + // _, err := http.Get(apiURL + "/readyz") + // return err + // }, "10m", "10s").ShouldNot(HaveOccurred()) + // }) It("generates all the fields with random data", func() { + var llmClient llm.LLMClient + if useRealLocalAI { + llmClient = llm.NewClient(apiKey, apiURL, testModel) + } else { + llmClient = &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_1", + Type: "function", + Function: openai.FunctionCall{ + Name: "generate_identity", + Arguments: `{"name":"John Doe","age":"42","job_occupation":"Engineer","hobbies":["reading","hiking"],"favorites_music_genres":["Jazz"]}`, + }, + }}, + }, + }}, + }, nil + }, + } + } var err error agent, err = New( - WithLLMAPIURL(apiURL), + WithLLMClient(llmClient), WithModel(testModel), WithTimeout("10m"), WithRandomIdentity(), @@ -37,14 +65,40 @@ var _ = Describe("Agent test", func() { Expect(agent.Character.MusicTaste).ToNot(BeEmpty()) }) It("detect an invalid character", func() { + mock := &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{}, fmt.Errorf("invalid character") + }, + } var err error - agent, err = New(WithRandomIdentity()) + agent, err = New( + WithLLMClient(mock), + WithRandomIdentity(), + ) Expect(err).To(HaveOccurred()) }) It("generates all the fields", func() { + mock := &llm.MockClient{ + CreateChatCompletionFunc: func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + ToolCalls: []openai.ToolCall{{ + ID: "tool_call_id_2", + Type: "function", + Function: openai.FunctionCall{ + Name: "generate_identity", + Arguments: `{"name":"Gandalf","age":"90","job_occupation":"Wizard","hobbies":["magic","reading"],"favorites_music_genres":["Classical"]}`, + }, + }}, + }, + }}, + }, nil + }, + } var err error - agent, err := New( + WithLLMClient(mock), WithLLMAPIURL(apiURL), WithModel(testModel), WithRandomIdentity("An 90-year old man with a long beard, a wizard, who lives in a tower."), diff --git a/pkg/llm/client.go b/pkg/llm/client.go index dc27afe4..e94a588f 100644 --- a/pkg/llm/client.go +++ b/pkg/llm/client.go @@ -1,13 +1,33 @@ package llm import ( + "context" "net/http" "time" + "github.com/mudler/LocalAGI/pkg/xlog" "github.com/sashabaranov/go-openai" ) -func NewClient(APIKey, URL, timeout string) *openai.Client { +type LLMClient interface { + CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) + CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) +} + +type realClient struct { + *openai.Client +} + +func (r *realClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + return r.Client.CreateChatCompletion(ctx, req) +} + +func (r *realClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) { + return r.Client.CreateImage(ctx, req) +} + +// NewClient returns a real OpenAI client as LLMClient +func NewClient(APIKey, URL, timeout string) LLMClient { // Set up OpenAI client if APIKey == "" { //log.Fatal("OPENAI_API_KEY environment variable not set") @@ -18,11 +38,12 @@ func NewClient(APIKey, URL, timeout string) *openai.Client { dur, err := time.ParseDuration(timeout) if err != nil { + xlog.Error("Failed to parse timeout", "error", err) dur = 150 * time.Second } config.HTTPClient = &http.Client{ Timeout: dur, } - return openai.NewClientWithConfig(config) + return &realClient{openai.NewClientWithConfig(config)} } diff --git a/pkg/llm/json.go b/pkg/llm/json.go index c4f48d1f..34d413b2 100644 --- a/pkg/llm/json.go +++ b/pkg/llm/json.go @@ -10,7 +10,7 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) -func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, guidance, model string, i jsonschema.Definition, dst any) error { +func GenerateTypedJSONWithGuidance(ctx context.Context, client LLMClient, guidance, model string, i jsonschema.Definition, dst any) error { return GenerateTypedJSONWithConversation(ctx, client, []openai.ChatCompletionMessage{ { Role: "user", @@ -19,7 +19,7 @@ func GenerateTypedJSONWithGuidance(ctx context.Context, client *openai.Client, g }, model, i, dst) } -func GenerateTypedJSONWithConversation(ctx context.Context, client *openai.Client, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error { +func GenerateTypedJSONWithConversation(ctx context.Context, client LLMClient, conv []openai.ChatCompletionMessage, model string, i jsonschema.Definition, dst any) error { toolName := "json" decision := openai.ChatCompletionRequest{ Model: model, diff --git a/pkg/llm/mock_client.go b/pkg/llm/mock_client.go new file mode 100644 index 00000000..52bc5278 --- /dev/null +++ b/pkg/llm/mock_client.go @@ -0,0 +1,25 @@ +package llm + +import ( + "context" + "github.com/sashabaranov/go-openai" +) + +type MockClient struct { + CreateChatCompletionFunc func(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) + CreateImageFunc func(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) +} + +func (m *MockClient) CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) { + if m.CreateChatCompletionFunc != nil { + return m.CreateChatCompletionFunc(ctx, req) + } + return openai.ChatCompletionResponse{}, nil +} + +func (m *MockClient) CreateImage(ctx context.Context, req openai.ImageRequest) (openai.ImageResponse, error) { + if m.CreateImageFunc != nil { + return m.CreateImageFunc(ctx, req) + } + return openai.ImageResponse{}, nil +} diff --git a/services/filters/classifier.go b/services/filters/classifier.go index e517a7cb..840663cc 100644 --- a/services/filters/classifier.go +++ b/services/filters/classifier.go @@ -8,7 +8,6 @@ import ( "github.com/mudler/LocalAGI/core/types" "github.com/mudler/LocalAGI/pkg/config" "github.com/mudler/LocalAGI/pkg/llm" - "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -16,7 +15,7 @@ const FilterClassifier = "classifier" type ClassifierFilter struct { name string - client *openai.Client + client llm.LLMClient model string description string allowOnMatch bool