Skip to content

Commit

Permalink
feat: add map-reduce summary
Browse files Browse the repository at this point in the history
  • Loading branch information
rpidanny committed Jul 5, 2024
1 parent 90708dc commit 5d651dc
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 90 deletions.
12 changes: 11 additions & 1 deletion src/commands/search/accession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import outputFlag from '../../inputs/flags/output.flag.js'
import questionFlag from '../../inputs/flags/question.flag.js'
import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js'
import summaryFlag from '../../inputs/flags/summary.flag.js'
import summaryMethodFlag from '../../inputs/flags/summary-method.flag.js'
import { PaperSearchService } from '../../services/search/paper-search.service.js'

export default class SearchAccession extends BaseCommand<typeof SearchAccession> {
Expand Down Expand Up @@ -47,6 +48,7 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
legacy: legacyFlag,
headless: headlessFlag,
summary: summaryFlag,
'summary-method': summaryMethodFlag,
llm: llmProviderFlag,
question: questionFlag,
}
Expand Down Expand Up @@ -88,7 +90,14 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
}

public async run(): Promise<void> {
const { count, output, 'accession-number-regex': filterPattern, summary, question } = this.flags
const {
count,
output,
'accession-number-regex': filterPattern,
summary,
question,
'summary-method': summaryMethod,
} = this.flags
const { keywords } = this.args

this.logger.info(`Searching papers with Accession Numbers (${filterPattern}) for: ${keywords}`)
Expand All @@ -98,6 +107,7 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
minItemCount: count,
filterPattern,
summarize: summary,
summaryMethod,
question,
})

Expand Down
5 changes: 4 additions & 1 deletion src/commands/search/papers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import outputFlag from '../../inputs/flags/output.flag.js'
import questionFlag from '../../inputs/flags/question.flag.js'
import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js'
import summaryFlag from '../../inputs/flags/summary.flag.js'
import summaryMethodFlag from '../../inputs/flags/summary-method.flag.js'
import { PaperSearchService } from '../../services/search/paper-search.service.js'

export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
Expand Down Expand Up @@ -41,6 +42,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
legacy: legacyFlag,
headless: headlessFlag,
summary: summaryFlag,
'summary-method': summaryMethodFlag,
llm: llmProviderFlag,
question: questionFlag,
}
Expand Down Expand Up @@ -84,7 +86,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
}

public async run(): Promise<void> {
const { count, output, filter, summary, question } = this.flags
const { count, output, filter, summary, question, 'summary-method': summaryMethod } = this.flags
const { keywords } = this.args

this.logger.info(`Searching papers for: ${keywords}`)
Expand All @@ -94,6 +96,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
minItemCount: count,
filterPattern: filter,
summarize: summary,
summaryMethod,
question,
})

Expand Down
20 changes: 20 additions & 0 deletions src/inputs/flags/summary-method.flag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import * as oclif from '@oclif/core'

import { SummaryMethod } from '../../services/llm/llm.service'

export default oclif.Flags.custom<SummaryMethod>({
summary: 'The method to use for generating summaries.',
description: 'See FAQ for differences between methods.',
options: Object.values(SummaryMethod) as string[],
helpValue: Object.values(SummaryMethod).join('|'),
default: SummaryMethod.MapReduce,
parse: async (input: string): Promise<SummaryMethod> => {
if (Object.values(SummaryMethod).includes(input as SummaryMethod)) {
return input as SummaryMethod
} else {
throw new Error(
`Invalid Summary Method : ${input}. Must be one of ${Object.values(SummaryMethod).join(', ')}`,
)
}
},
})()
32 changes: 29 additions & 3 deletions src/services/llm/llm.service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { jest } from '@jest/globals'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { mock } from 'jest-mock-extended'

import { LLMService } from './llm.service'
import { LLMService, SummaryMethod } from './llm.service'

describe('LLMService', () => {
const mockBaseLanguageModel = mock<BaseLanguageModel>()
Expand All @@ -27,7 +27,9 @@ describe('LLMService', () => {
const inputText = 'input text'
mockBaseLanguageModel.invoke.mockResolvedValue('summary')

await expect(llmService.summarize(inputText)).resolves.toEqual('summary')
await expect(llmService.summarize(inputText, SummaryMethod.Refine)).resolves.toEqual(
'summary',
)

expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(1)
})
Expand All @@ -36,11 +38,35 @@ describe('LLMService', () => {
const inputText = 'input text'.repeat(10_000)
mockBaseLanguageModel.invoke.mockResolvedValue('summary')

await expect(llmService.summarize(inputText)).resolves.toEqual('summary')
await expect(llmService.summarize(inputText, SummaryMethod.Refine)).resolves.toEqual(
'summary',
)

// 2 calls for each chunk and 1 call for final summary
expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(3)
})

it('should call llm once for short text with map_reduce method', async () => {
const inputText = 'input text'
mockBaseLanguageModel.invoke.mockResolvedValue('summary')

await expect(llmService.summarize(inputText, SummaryMethod.MapReduce)).resolves.toEqual(
'summary',
)

expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(11)
})

it('should call llm n times for longer text with map_reduce method', async () => {
const inputText = 'input text'.repeat(10_000)
mockBaseLanguageModel.invoke.mockResolvedValue('summary')

await expect(llmService.summarize(inputText, SummaryMethod.MapReduce)).resolves.toEqual(
'summary',
)

expect(mockBaseLanguageModel.invoke).toHaveBeenCalledTimes(31)
})
})

describe('ask', () => {
Expand Down
157 changes: 80 additions & 77 deletions src/services/llm/llm.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,126 +13,129 @@ import {
import { TokenTextSplitter } from 'langchain/text_splitter'
import { Service } from 'typedi'

import { MAP_PROMPT, REDUCE_PROMPT } from './prompt-templates/map-reduce.template.js'
import { SUMMARY_PROMPT, SUMMARY_REFINE_PROMPT } from './prompt-templates/summary.template.js'
import * as qaTemplate from './prompt-templates/question-answer.template.js'
import * as summaryMapReduceTemplate from './prompt-templates/summary.map-reduce.template.js'
import * as summaryRefineTemplate from './prompt-templates/summary.refine.template.js'

export enum SummaryMethod {
Refine = 'refine',
MapReduce = 'map_reduce',
}

type TChain = RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain

@Service()
export class LLMService {
summarizeChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain
qaChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain

textSplitter!: TokenTextSplitter
private summarizationChains: { [key in SummaryMethod]: TChain }
private qaChain: TChain
private textSplitter: TokenTextSplitter

constructor(
readonly llm: BaseLanguageModel,
private readonly logger?: Quill,
) {
this.textSplitter = new TokenTextSplitter({
chunkSize: 10_000,
chunkOverlap: 500,
})

this.summarizeChain = loadSummarizationChain(llm, {
type: 'refine',
verbose: false,
questionPrompt: SUMMARY_PROMPT,
refinePrompt: SUMMARY_REFINE_PROMPT,
})
this.textSplitter = new TokenTextSplitter({ chunkSize: 10_000, chunkOverlap: 500 })

this.summarizationChains = {
[SummaryMethod.Refine]: loadSummarizationChain(llm, {
type: 'refine',
verbose: false,
questionPrompt: summaryRefineTemplate.SUMMARY_PROMPT,
refinePrompt: summaryRefineTemplate.SUMMARY_REFINE_PROMPT,
}),
[SummaryMethod.MapReduce]: loadSummarizationChain(llm, {
type: 'map_reduce',
verbose: false,
combineMapPrompt: summaryMapReduceTemplate.MAP_PROMPT,
combinePrompt: summaryMapReduceTemplate.REDUCE_PROMPT,
}),
}

this.qaChain = loadQAMapReduceChain(llm, {
verbose: false,
combineMapPrompt: MAP_PROMPT,
combinePrompt: REDUCE_PROMPT,
combineMapPrompt: qaTemplate.MAP_PROMPT,
combinePrompt: qaTemplate.REDUCE_PROMPT,
})
}

public async summarize(inputText: string) {
private createProgressBar(task: string, total: number): SingleBar {
const bar = new SingleBar(
{
clearOnComplete: true,
hideCursor: true,
format: `${chalk.magenta('Summarizing')} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`,
format: `${chalk.magenta(task)} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`,
},
Presets.shades_classic,
)

const document = new Document({
pageContent: inputText,
})
const docChunks = await this.textSplitter.splitDocuments([document])
bar.start(total, 0)

this.logger?.info(
`Summarizing document with ${inputText.length} chars (${docChunks.length} chunks)`,
)
return bar
}

bar.start(docChunks.length, 0)
private async getDocumentChunks(inputText: string): Promise<Document[]> {
const document = new Document({ pageContent: inputText })
return this.textSplitter.splitDocuments([document])
}

private async processChunks(
chain: TChain,
docChunks: Document[],
callbacks: any[],
question?: string,
): Promise<any> {
let docCount = 0

const resp = await this.summarizeChain.invoke(
{
// eslint-disable-next-line camelcase
input_documents: docChunks,
},
return chain.invoke(
// eslint-disable-next-line camelcase
{ input_documents: docChunks, question },
{
callbacks: [
{
handleLLMEnd: async () => {
bar.update(++docCount)
},
callbacks: callbacks.map(callback => ({
handleLLMEnd: async res => {
callback(++docCount)
if (process.env.DARWIN_GOD_MODE) {
this.logger?.debug(
`LLM Response: ${res.generations.map(g => g.map(t => t.text).join('\n')).join(',')}`,
)
}
},
],
})),
},
)
}

bar.stop()
public async summarize(
inputText: string,
method: SummaryMethod = SummaryMethod.MapReduce,
): Promise<string> {
const docChunks = await this.getDocumentChunks(inputText)
const totalSteps = docChunks.length + (method === SummaryMethod.MapReduce ? 1 : 0)

return resp.output_text
this.logger?.info(
`Summarizing document with ${inputText.length} chars (${docChunks.length} chunks)`,
)
const bar = this.createProgressBar('Summarizing', totalSteps)

const resp = await this.processChunks(this.summarizationChains[method], docChunks, [
bar.update.bind(bar),
])

bar.stop()
return method === SummaryMethod.MapReduce ? resp.text : resp.output_text
}

public async ask(inputText: string, question: string): Promise<string> {
const bar = new SingleBar(
{
clearOnComplete: true,
hideCursor: true,
format: `${chalk.magenta('Querying')} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`,
},
Presets.shades_classic,
)
const docChunks = await this.getDocumentChunks(inputText)

const document = new Document({
pageContent: inputText,
})
const docChunks = await this.textSplitter.splitDocuments([document])
const totalSteps = docChunks.length + 1

this.logger?.info(
`Querying "${question}" on document with ${inputText.length} chars (${docChunks.length} chunks)`,
)
const bar = this.createProgressBar('Querying', totalSteps)

// n map + 1 reduce
bar.start(docChunks.length + 1, 0)

let docCount = 0

const resp = await this.qaChain.invoke(
{
// eslint-disable-next-line camelcase
input_documents: docChunks,
question,
},
{
callbacks: [
{
handleLLMEnd: async () => {
bar.update(++docCount)
},
},
],
},
)
const resp = await this.processChunks(this.qaChain, docChunks, [bar.update.bind(bar)], question)

bar.stop()

return resp.text
}
}
34 changes: 34 additions & 0 deletions src/services/llm/prompt-templates/summary.map-reduce.template.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { PromptTemplate } from '@langchain/core/prompts'

// Map step of the map-reduce summarization chain: summarizes ONE chunk of the
// split document in isolation. LangChain interpolates the chunk contents into
// the `{text}` placeholder at run time — do not rename it.
export const MAP_TEMPLATE = `
You are an expert researcher skilled in summarizing research papers.
Your task is to summarize the following research paper text:
\`\`\`md
{text}
\`\`\`
If the text is not from a research paper, ignore it.
Provide a concise summary including the key ideas and findings as a paragraph.
[IMPORTANT] Only return the summary without saying anything else.
SUMMARY:`

// Reduce step: receives the concatenated per-chunk summaries (again via the
// `{text}` placeholder) and collapses them into a single final summary.
export const REDUCE_TEMPLATE = `
You are an expert researcher skilled in summarizing research papers.
Your task is to consolidate the following summaries into a single, concise summary:
\`\`\`txt
{text}
\`\`\`
Provide a final summary of the main themes as a paragraph.
[IMPORTANT] Only return the summary without saying anything else.
CONCISE SUMMARY:`

// Compiled prompt objects consumed by loadSummarizationChain({ type: 'map_reduce' })
// as `combineMapPrompt` / `combinePrompt` respectively.
export const MAP_PROMPT = PromptTemplate.fromTemplate(MAP_TEMPLATE)
export const REDUCE_PROMPT = PromptTemplate.fromTemplate(REDUCE_TEMPLATE)
Loading

0 comments on commit 5d651dc

Please sign in to comment.