From 9306b87873759a1043c474a33aefeaa89c5a1ba0 Mon Sep 17 00:00:00 2001 From: rpidanny Date: Thu, 4 Jul 2024 16:54:55 +0200 Subject: [PATCH] test: add llm service basic test --- jest.config.ts | 8 ++-- src/services/llm/llm.service.spec.ts | 71 ++++++++++++++++++++++++++++ src/services/llm/llm.service.ts | 4 +- 3 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 src/services/llm/llm.service.spec.ts diff --git a/jest.config.ts b/jest.config.ts index d74d8cc..b0d9a05 100644 --- a/jest.config.ts +++ b/jest.config.ts @@ -43,13 +43,15 @@ const config: JestConfigWithTsJest = { '/coverage/', '/src/commands/', '\\.config\\.ts$', + '/src/services/chat/autonomous-agent.ts', + '/src/utils/ui/output.ts', ], coverageThreshold: { global: { - statements: 81, + statements: 95, branches: 90, - functions: 90, - lines: 81, + functions: 94, + lines: 95, }, }, } diff --git a/src/services/llm/llm.service.spec.ts b/src/services/llm/llm.service.spec.ts new file mode 100644 index 0000000..ba8b917 --- /dev/null +++ b/src/services/llm/llm.service.spec.ts @@ -0,0 +1,71 @@ +import { jest } from '@jest/globals' +import { BaseLanguageModel } from '@langchain/core/language_models/base' +import { mock } from 'jest-mock-extended' + +import { LLMService } from './llm.service' + +describe('LLMService', () => { + const mockBaseLanguageModel = mock() + + let llmService: LLMService + + beforeEach(() => { + llmService = new LLMService( + mock({ + pipe: () => mockBaseLanguageModel, + }), + ) + }) + + afterEach(() => { + jest.clearAllMocks() + jest.resetAllMocks() + }) + + describe('summarize', () => { + it('should call llm once for short text', async () => { + const inputText = 'input text' + mockBaseLanguageModel.invoke.mockResolvedValue('summary') + + await expect(llmService.summarize(inputText)).resolves.toEqual('summary') + + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(1) + }) + + it('should call llm n times for longer text', async () => { + const inputText = 'input 
text'.repeat(10_000) + mockBaseLanguageModel.invoke.mockResolvedValue('summary') + + await expect(llmService.summarize(inputText)).resolves.toEqual('summary') + + // 1 call for each of the 2 chunks and 1 call for the final summary + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(3) + }) + }) + + describe('ask', () => { + it('should call llm multiple times for short text', async () => { + const inputText = 'input text' + const question = 'question' + + mockBaseLanguageModel.invoke.mockResolvedValue('answer') + mockBaseLanguageModel.getNumTokens.mockResolvedValue(3) + + await expect(llmService.ask(inputText, question)).resolves.toEqual('answer') + + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(11) + }) + + it('should call llm n times for longer text', async () => { + const inputText = 'input text'.repeat(10_000) + const question = 'question' + + mockBaseLanguageModel.invoke.mockResolvedValue('answer') + mockBaseLanguageModel.getNumTokens.mockResolvedValue(3) + + await expect(llmService.ask(inputText, question)).resolves.toEqual('answer') + + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(31) + }) + }) +}) diff --git a/src/services/llm/llm.service.ts b/src/services/llm/llm.service.ts index c58db58..7d4a9b5 100644 --- a/src/services/llm/llm.service.ts +++ b/src/services/llm/llm.service.ts @@ -38,7 +38,9 @@ export class LLMService { refinePrompt: SUMMARY_REFINE_PROMPT, }) - this.qaChain = loadQAMapReduceChain(llm) + this.qaChain = loadQAMapReduceChain(llm, { + verbose: false, + }) } public async summarize(inputText: string) {