From 5d651dcda5339ab0c8e917fb260c42368ca96224 Mon Sep 17 00:00:00 2001 From: rpidanny Date: Fri, 5 Jul 2024 19:04:17 +0200 Subject: [PATCH] feat: add map-reduce summary --- src/commands/search/accession.ts | 12 +- src/commands/search/papers.ts | 5 +- src/inputs/flags/summary-method.flag.ts | 20 +++ src/services/llm/llm.service.spec.ts | 32 +++- src/services/llm/llm.service.ts | 157 +++++++++--------- ...emplate.ts => question-answer.template.ts} | 0 .../summary.map-reduce.template.ts | 34 ++++ ...template.ts => summary.refine.template.ts} | 13 +- src/services/search/interfaces.ts | 2 + src/services/search/paper-search.service.ts | 13 +- 10 files changed, 198 insertions(+), 90 deletions(-) create mode 100644 src/inputs/flags/summary-method.flag.ts rename src/services/llm/prompt-templates/{map-reduce.template.ts => question-answer.template.ts} (100%) create mode 100644 src/services/llm/prompt-templates/summary.map-reduce.template.ts rename src/services/llm/prompt-templates/{summary.template.ts => summary.refine.template.ts} (86%) diff --git a/src/commands/search/accession.ts b/src/commands/search/accession.ts index 3e5a89b..0cd3f2d 100644 --- a/src/commands/search/accession.ts +++ b/src/commands/search/accession.ts @@ -15,6 +15,7 @@ import outputFlag from '../../inputs/flags/output.flag.js' import questionFlag from '../../inputs/flags/question.flag.js' import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js' import summaryFlag from '../../inputs/flags/summary.flag.js' +import summaryMethodFlag from '../../inputs/flags/summary-method.flag.js' import { PaperSearchService } from '../../services/search/paper-search.service.js' export default class SearchAccession extends BaseCommand { @@ -47,6 +48,7 @@ export default class SearchAccession extends BaseCommand legacy: legacyFlag, headless: headlessFlag, summary: summaryFlag, + 'summary-method': summaryMethodFlag, llm: llmProviderFlag, question: questionFlag, } @@ -88,7 +90,14 @@ export default class SearchAccession extends BaseCommand } public async run(): Promise { - const { count, output, 'accession-number-regex': filterPattern, summary, question } = this.flags + const { + count, + output, + 'accession-number-regex': filterPattern, + summary, + question, + 'summary-method': summaryMethod, + } = this.flags const { keywords } = this.args this.logger.info(`Searching papers with Accession Numbers (${filterPattern}) for: ${keywords}`) @@ -98,6 +107,7 @@ export default class SearchAccession extends BaseCommand minItemCount: count, filterPattern, summarize: summary, + summaryMethod, question, }) diff --git a/src/commands/search/papers.ts b/src/commands/search/papers.ts index 6c27cc1..1d28b03 100644 --- a/src/commands/search/papers.ts +++ b/src/commands/search/papers.ts @@ -14,6 +14,7 @@ import outputFlag from '../../inputs/flags/output.flag.js' import questionFlag from '../../inputs/flags/question.flag.js' import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js' import summaryFlag from '../../inputs/flags/summary.flag.js' +import summaryMethodFlag from '../../inputs/flags/summary-method.flag.js' import { PaperSearchService } from '../../services/search/paper-search.service.js' export default class SearchPapers extends BaseCommand { @@ -41,6 +42,7 @@ export default class SearchPapers extends BaseCommand { legacy: legacyFlag, headless: headlessFlag, summary: summaryFlag, + 'summary-method': summaryMethodFlag, llm: llmProviderFlag, question: questionFlag, } @@ -84,7 +86,7 @@ export default class SearchPapers extends BaseCommand { } public async run(): Promise { - const { count, output, filter, summary, question } = this.flags + const { count, output, filter, summary, question, 'summary-method': summaryMethod } = this.flags const { keywords } = this.args this.logger.info(`Searching papers for: ${keywords}`) @@ -94,6 +96,7 @@ export default class SearchPapers extends BaseCommand { minItemCount: count, filterPattern: filter, summarize: summary, + summaryMethod, question, }) diff --git a/src/inputs/flags/summary-method.flag.ts b/src/inputs/flags/summary-method.flag.ts new file mode 100644 index 0000000..33056f0 --- /dev/null +++ b/src/inputs/flags/summary-method.flag.ts @@ -0,0 +1,20 @@ +import * as oclif from '@oclif/core' + +import { SummaryMethod } from '../../services/llm/llm.service' + +export default oclif.Flags.custom({ + summary: 'The method to use for generating summaries.', + description: 'See FAQ for differences between methods.', + options: Object.values(SummaryMethod) as string[], + helpValue: Object.values(SummaryMethod).join('|'), + default: SummaryMethod.MapReduce, + parse: async (input: string): Promise => { + if (Object.values(SummaryMethod).includes(input as SummaryMethod)) { + return input as SummaryMethod + } else { + throw new Error( + `Invalid Summary Method : ${input}. Must be one of ${Object.values(SummaryMethod).join(', ')}`, + ) + } + }, +})() diff --git a/src/services/llm/llm.service.spec.ts b/src/services/llm/llm.service.spec.ts index ba8b917..37c82e1 100644 --- a/src/services/llm/llm.service.spec.ts +++ b/src/services/llm/llm.service.spec.ts @@ -2,7 +2,7 @@ import { jest } from '@jest/globals' import { BaseLanguageModel } from '@langchain/core/language_models/base' import { mock } from 'jest-mock-extended' -import { LLMService } from './llm.service' +import { LLMService, SummaryMethod } from './llm.service' describe('LLMService', () => { const mockBaseLanguageModel = mock() @@ -27,7 +27,9 @@ describe('LLMService', () => { const inputText = 'input text' mockBaseLanguageModel.invoke.mockResolvedValue('summary') - await expect(llmService.summarize(inputText)).resolves.toEqual('summary') + await expect(llmService.summarize(inputText, SummaryMethod.Refine)).resolves.toEqual( + 'summary', + ) expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(1) }) @@ -36,11 +38,35 @@ describe('LLMService', () => { const inputText = 'input text'.repeat(10_000) mockBaseLanguageModel.invoke.mockResolvedValue('summary') - await expect(llmService.summarize(inputText)).resolves.toEqual('summary') + await expect(llmService.summarize(inputText, SummaryMethod.Refine)).resolves.toEqual( + 'summary', + ) // 2 calls for each chunk and 1 call for final summary expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(3) }) + + it('should call llm once for short text with map_reduce method', async () => { + const inputText = 'input text' + mockBaseLanguageModel.invoke.mockResolvedValue('summary') + + await expect(llmService.summarize(inputText, SummaryMethod.MapReduce)).resolves.toEqual( + 'summary', + ) + + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(11) + }) + + it('should call llm n times for longer text with map_reduce method', async () => { + const inputText = 'input text'.repeat(10_000) + mockBaseLanguageModel.invoke.mockResolvedValue('summary') + + await expect(llmService.summarize(inputText, SummaryMethod.MapReduce)).resolves.toEqual( + 'summary', + ) + + expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(31) + }) }) describe('ask', () => { diff --git a/src/services/llm/llm.service.ts b/src/services/llm/llm.service.ts index 0e87b9e..fbc284f 100644 --- a/src/services/llm/llm.service.ts +++ b/src/services/llm/llm.service.ts @@ -13,126 +13,129 @@ import { import { TokenTextSplitter } from 'langchain/text_splitter' import { Service } from 'typedi' -import { MAP_PROMPT, REDUCE_PROMPT } from './prompt-templates/map-reduce.template.js' -import { SUMMARY_PROMPT, SUMMARY_REFINE_PROMPT } from './prompt-templates/summary.template.js' +import * as qaTemplate from './prompt-templates/question-answer.template.js' +import * as summaryMapReduceTemplate from './prompt-templates/summary.map-reduce.template.js' +import * as summaryRefineTemplate from './prompt-templates/summary.refine.template.js' + +export enum SummaryMethod { + Refine = 'refine', + MapReduce = 'map_reduce', +} + +type TChain = RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain @Service() export class LLMService { - summarizeChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain - qaChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain - - textSplitter!: TokenTextSplitter + private summarizationChains: { [key in SummaryMethod]: TChain } + private qaChain: TChain + private textSplitter: TokenTextSplitter constructor( readonly llm: BaseLanguageModel, private readonly logger?: Quill, ) { - this.textSplitter = new TokenTextSplitter({ - chunkSize: 10_000, - chunkOverlap: 500, - }) - - this.summarizeChain = loadSummarizationChain(llm, { - type: 'refine', - verbose: false, - questionPrompt: SUMMARY_PROMPT, - refinePrompt: SUMMARY_REFINE_PROMPT, - }) + this.textSplitter = new TokenTextSplitter({ chunkSize: 10_000, chunkOverlap: 500 }) + + this.summarizationChains = { + [SummaryMethod.Refine]: loadSummarizationChain(llm, { + type: 'refine', + verbose: false, + questionPrompt: summaryRefineTemplate.SUMMARY_PROMPT, + refinePrompt: summaryRefineTemplate.SUMMARY_REFINE_PROMPT, + }), + [SummaryMethod.MapReduce]: loadSummarizationChain(llm, { + type: 'map_reduce', + verbose: false, + combineMapPrompt: summaryMapReduceTemplate.MAP_PROMPT, + combinePrompt: summaryMapReduceTemplate.REDUCE_PROMPT, + }), + } this.qaChain = loadQAMapReduceChain(llm, { verbose: false, - combineMapPrompt: MAP_PROMPT, - combinePrompt: REDUCE_PROMPT, + combineMapPrompt: qaTemplate.MAP_PROMPT, + combinePrompt: qaTemplate.REDUCE_PROMPT, }) } - public async summarize(inputText: string) { + private createProgressBar(task: string, total: number): SingleBar { const bar = new SingleBar( { clearOnComplete: true, hideCursor: true, - format: `${chalk.magenta('Summarizing')} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`, + format: `${chalk.magenta(task)} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`, }, Presets.shades_classic, ) - const document = new Document({ - pageContent: inputText, - }) - const docChunks = await this.textSplitter.splitDocuments([document]) + bar.start(total, 0) - this.logger?.info( - `Summarizing document with ${inputText.length} chars (${docChunks.length} chunks)`, - ) + return bar + } - bar.start(docChunks.length, 0) + private async getDocumentChunks(inputText: string): Promise { + const document = new Document({ pageContent: inputText }) + return this.textSplitter.splitDocuments([document]) + } + private async processChunks( + chain: TChain, + docChunks: Document[], + callbacks: any[], + question?: string, + ): Promise { let docCount = 0 - - const resp = await this.summarizeChain.invoke( - { - // eslint-disable-next-line camelcase - input_documents: docChunks, - }, + return chain.invoke( + // eslint-disable-next-line camelcase + { input_documents: docChunks, question }, { - callbacks: [ - { - handleLLMEnd: async () => { - bar.update(++docCount) - }, + callbacks: callbacks.map(callback => ({ + handleLLMEnd: async res => { + callback(++docCount) + if (process.env.DARWIN_GOD_MODE) { + this.logger?.debug( + `LLM Response: ${res.generations.map(g => g.map(t => t.text).join('\n')).join(',')}`, + ) + } }, - ], + })), }, ) + } - bar.stop() + public async summarize( + inputText: string, + method: SummaryMethod = SummaryMethod.MapReduce, + ): Promise { + const docChunks = await this.getDocumentChunks(inputText) + const totalSteps = docChunks.length + (method === SummaryMethod.MapReduce ? 1 : 0) - return resp.output_text + this.logger?.info( + `Summarizing document with ${inputText.length} chars (${docChunks.length} chunks)`, + ) + const bar = this.createProgressBar('Summarizing', totalSteps) + + const resp = await this.processChunks(this.summarizationChains[method], docChunks, [ + bar.update.bind(bar), + ]) + + bar.stop() + return method === SummaryMethod.MapReduce ? resp.text : resp.output_text } public async ask(inputText: string, question: string): Promise { - const bar = new SingleBar( - { - clearOnComplete: true, - hideCursor: true, - format: `${chalk.magenta('Querying')} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`, - }, - Presets.shades_classic, - ) + const docChunks = await this.getDocumentChunks(inputText) - const document = new Document({ - pageContent: inputText, - }) - const docChunks = await this.textSplitter.splitDocuments([document]) + const totalSteps = docChunks.length + 1 this.logger?.info( `Querying "${question}" on document with ${inputText.length} chars (${docChunks.length} chunks)`, ) + const bar = this.createProgressBar('Querying', totalSteps) - // n map + 1 reduce - bar.start(docChunks.length + 1, 0) - - let docCount = 0 - - const resp = await this.qaChain.invoke( - { - // eslint-disable-next-line camelcase - input_documents: docChunks, - question, - }, - { - callbacks: [ - { - handleLLMEnd: async () => { - bar.update(++docCount) - }, - }, - ], - }, - ) + const resp = await this.processChunks(this.qaChain, docChunks, [bar.update.bind(bar)], question) bar.stop() - return resp.text } } diff --git a/src/services/llm/prompt-templates/map-reduce.template.ts b/src/services/llm/prompt-templates/question-answer.template.ts similarity index 100% rename from src/services/llm/prompt-templates/map-reduce.template.ts rename to src/services/llm/prompt-templates/question-answer.template.ts diff --git a/src/services/llm/prompt-templates/summary.map-reduce.template.ts b/src/services/llm/prompt-templates/summary.map-reduce.template.ts new file mode 100644 index 0000000..badf817 --- /dev/null +++ b/src/services/llm/prompt-templates/summary.map-reduce.template.ts @@ -0,0 +1,34 @@ +import { PromptTemplate } from '@langchain/core/prompts' + +export const MAP_TEMPLATE = ` +You are an expert researcher skilled in summarizing research papers. +Your task is to summarize the following research paper text: + +\`\`\`md +{text} +\`\`\` + +If the text is not from a research paper, ignore it. + +Provide a concise summary including the key ideas and findings as a paragraph. + +[IMPORTANT] Only return the summary without saying anything else. + +SUMMARY:` + +export const REDUCE_TEMPLATE = ` +You are an expert researcher skilled in summarizing research papers. +Your task is to consolidate the following summaries into a single, concise summary: + +\`\`\`txt +{text} +\`\`\` + +Provide a final summary of the main themes as a paragraph. + +[IMPORTANT] Only return the summary without saying anything else. + +CONCISE SUMMARY:` + +export const MAP_PROMPT = PromptTemplate.fromTemplate(MAP_TEMPLATE) +export const REDUCE_PROMPT = PromptTemplate.fromTemplate(REDUCE_TEMPLATE) diff --git a/src/services/llm/prompt-templates/summary.template.ts b/src/services/llm/prompt-templates/summary.refine.template.ts similarity index 86% rename from src/services/llm/prompt-templates/summary.template.ts rename to src/services/llm/prompt-templates/summary.refine.template.ts index 90b2b32..a0b66e7 100644 --- a/src/services/llm/prompt-templates/summary.template.ts +++ b/src/services/llm/prompt-templates/summary.refine.template.ts @@ -5,13 +5,14 @@ You are an expert researcher who is very good at reading and summarizing researc Your goal is to create a summary of a research paper. Below you find the text content of the paper: -\`\`\`txt +\`\`\`md {text} \`\`\` -Total output will be a summary of the paper including the key ideas, findings of the paper as a paragraph. -If the text is about cookies, cookie policy and preferences, please ignore it. +If the text is a non research paper content, ignore them. + +Total output will be a summary of the paper including the key ideas, findings of the paper as a paragraph. [IMPORTANT] Only return the summary without saying anything else. @@ -23,20 +24,22 @@ You are an expert researcher who is very good at reading and summarizing researc Your goal is to create a summary of a research paper. We have provided an existing summary up to a certain point: + \`\`\`txt {existing_answer} \`\`\` Below you find the text content of the paper: -\`\`\`txt +\`\`\`md {text} \`\`\` +If the text is a non research paper content, ignore them. + Given the new context, refine the summary to be more accurate and informative. If the context isn't useful, return the original summary. Total output will be a summary of the paper including the key ideas, findings of the paper as a paragraph. -If the text is about cookies, cookie policy and preferences, please ignore it and return the original summary. [IMPORTANT] Only return the summary without saying anything else. diff --git a/src/services/search/interfaces.ts b/src/services/search/interfaces.ts index 6efdfd9..0241505 100644 --- a/src/services/search/interfaces.ts +++ b/src/services/search/interfaces.ts @@ -1,6 +1,7 @@ import { ICitation, IPaperSource } from '@rpidanny/google-scholar' import { ITextMatch } from '../../utils/text/interfaces' +import { SummaryMethod } from '../llm/llm.service' export interface IPaperEntity { title: string @@ -19,6 +20,7 @@ export interface ISearchOptions { minItemCount: number filterPattern?: string summarize?: boolean + summaryMethod?: SummaryMethod question?: string onData?: (data: IPaperEntity) => Promise } diff --git a/src/services/search/paper-search.service.ts b/src/services/search/paper-search.service.ts index 41a6f8e..8c8b0c6 100644 --- a/src/services/search/paper-search.service.ts +++ b/src/services/search/paper-search.service.ts @@ -26,6 +26,7 @@ export class PaperSearchService { minItemCount, filterPattern, summarize, + summaryMethod, question, onData, }: ISearchOptions): Promise { @@ -34,7 +35,12 @@ export class PaperSearchService { await this.googleScholar.iteratePapers( { keywords }, async paper => { - const entity = await this.processPaper(paper, { filterPattern, summarize, question }) + const entity = await this.processPaper(paper, { + filterPattern, + summarize, + summaryMethod, + question, + }) if (!entity) return true papers.push(entity) @@ -91,8 +97,9 @@ export class PaperSearchService { { filterPattern, summarize, + summaryMethod, question, - }: Pick, + }: Pick, ): Promise { const entity = this.toEntity(paper) @@ -109,7 +116,7 @@ export class PaperSearchService { } if (summarize) { - entity.summary = await this.llmService.summarize(textContent) + entity.summary = await this.llmService.summarize(textContent, summaryMethod) } if (question) {