feat: add question flag to ask questions about papers
rpidanny committed Jul 4, 2024
1 parent 9e2607b commit 329394b
Showing 9 changed files with 110 additions and 10 deletions.
5 changes: 4 additions & 1 deletion src/commands/search/accession.ts
@@ -12,6 +12,7 @@ import headlessFlag from '../../inputs/flags/headless.flag.js'
import legacyFlag from '../../inputs/flags/legacy.flag.js'
import llmProviderFlag from '../../inputs/flags/llm-provider.flag.js'
import outputFlag from '../../inputs/flags/output.flag.js'
import questionFlag from '../../inputs/flags/question.flag.js'
import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js'
import summaryFlag from '../../inputs/flags/summary.flag.js'
import { PaperSearchService } from '../../services/search/paper-search.service.js'
@@ -47,6 +48,7 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
headless: headlessFlag,
summary: summaryFlag,
llm: llmProviderFlag,
question: questionFlag,
}

async init(): Promise<void> {
Expand Down Expand Up @@ -86,7 +88,7 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
}

public async run(): Promise<void> {
const { count, output, 'accession-number-regex': filterPattern, summary } = this.flags
const { count, output, 'accession-number-regex': filterPattern, summary, question } = this.flags
const { keywords } = this.args

this.logger.info(`Searching papers with Accession Numbers (${filterPattern}) for: ${keywords}`)
@@ -96,6 +98,7 @@ export default class SearchAccession extends BaseCommand<typeof SearchAccession>
minItemCount: count,
filterPattern,
summarize: summary,
question,
})

this.logger.info(`Exported papers list to: ${outputPath}`)
7 changes: 6 additions & 1 deletion src/commands/search/papers.ts
@@ -11,6 +11,7 @@ import headlessFlag from '../../inputs/flags/headless.flag.js'
import legacyFlag from '../../inputs/flags/legacy.flag.js'
import llmProviderFlag from '../../inputs/flags/llm-provider.flag.js'
import outputFlag from '../../inputs/flags/output.flag.js'
import questionFlag from '../../inputs/flags/question.flag.js'
import skipCaptchaFlag from '../../inputs/flags/skip-captcha.flag.js'
import summaryFlag from '../../inputs/flags/summary.flag.js'
import { PaperSearchService } from '../../services/search/paper-search.service.js'
@@ -41,6 +42,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
headless: headlessFlag,
summary: summaryFlag,
llm: llmProviderFlag,
question: questionFlag,
}

async init(): Promise<void> {
@@ -50,6 +52,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
headless,
concurrency,
summary,
question,
llm: llmProvider,
'skip-captcha': skipCaptcha,
legacy,
@@ -60,6 +63,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
headless,
concurrency,
summary,
question,
llmProvider,
skipCaptcha,
legacy,
@@ -80,7 +84,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
}

public async run(): Promise<void> {
const { count, output, filter, summary } = this.flags
const { count, output, filter, summary, question } = this.flags
const { keywords } = this.args

this.logger.info(`Searching papers for: ${keywords}`)
@@ -90,6 +94,7 @@ export default class SearchPapers extends BaseCommand<typeof SearchPapers> {
minItemCount: count,
filterPattern: filter,
summarize: summary,
question,
})

this.logger.info(`Exported papers list to: ${outputFile}`)
5 changes: 3 additions & 2 deletions src/containers/search.container.ts
@@ -15,21 +15,22 @@ export function initSearchContainer(
headless: boolean
concurrency: number
summary: boolean
question?: string
llmProvider: LLMProvider
skipCaptcha: boolean
legacy: boolean
},
config: TConfig,
logger: Quill,
) {
const { headless, concurrency, summary, llmProvider, skipCaptcha, legacy } = opts
const { headless, concurrency, summary, llmProvider, skipCaptcha, legacy, question } = opts

Container.set(
Odysseus,
new Odysseus({ headless, waitOnCaptcha: true, initHtml: getInitPageContent() }),
)
Container.set(Quill, logger)
Container.set(PaperSearchConfig, { concurrency: summary ? 1 : concurrency })
Container.set(PaperSearchConfig, { concurrency: summary || question != null ? 1 : concurrency })
Container.set(PaperServiceConfig, {
skipCaptcha,
legacyProcessing: legacy,
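
Note on the container change above: whenever a summary or a question is requested, the search container collapses concurrency to 1, presumably so LLM-backed processing runs one paper at a time. A minimal sketch of that rule, using a hypothetical helper name that is not part of the codebase:

// Hypothetical helper mirroring the container's concurrency rule.
function resolveConcurrency(opts: { summary: boolean; question?: string; concurrency: number }): number {
  // Any LLM-backed step (summary or Q&A) serializes paper processing.
  return opts.summary || opts.question != null ? 1 : opts.concurrency
}

resolveConcurrency({ summary: false, question: 'Is it better than RNN?', concurrency: 5 }) // => 1
resolveConcurrency({ summary: false, concurrency: 5 }) // => 5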
1 change: 1 addition & 0 deletions src/inputs/flags/char.ts
@@ -8,4 +8,5 @@ export enum FlagChar {
Headless = 'h',
IncludeSummary = 'S',
LogLevel = 'l',
Question = 'q',
}
9 changes: 9 additions & 0 deletions src/inputs/flags/question.flag.ts
@@ -0,0 +1,9 @@
import * as oclif from '@oclif/core'

import { FlagChar } from './char.js'

export default oclif.Flags.string({
char: FlagChar.Question,
helpValue: 'STRING',
summary: 'The question to ask the language model about the text content.',
})
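
The new flag is a plain oclif string option bound to -q. A hedged sketch of how a command declares and reads such a flag (the demo command below is illustrative; the real commands simply register the shared questionFlag shown above):

import { Command, Flags } from '@oclif/core'

// Illustrative stand-in for the real search commands.
export default class Demo extends Command {
  static flags = {
    question: Flags.string({
      char: 'q',
      helpValue: 'STRING',
      summary: 'The question to ask the language model about the text content.',
    }),
  }

  async run(): Promise<void> {
    const { flags } = await this.parse(Demo)
    if (flags.question) this.log(`Asking: ${flags.question}`)
  }
}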
48 changes: 48 additions & 0 deletions src/services/llm/llm.service.ts
@@ -4,6 +4,7 @@ import { Quill } from '@rpidanny/quill'
import chalk from 'chalk'
import { Presets, SingleBar } from 'cli-progress'
import {
loadQAMapReduceChain,
loadSummarizationChain,
MapReduceDocumentsChain,
RefineDocumentsChain,
@@ -17,6 +18,8 @@ import { SUMMARY_PROMPT, SUMMARY_REFINE_PROMPT } from './prompt-templates/summar
@Service()
export class LLMService {
summarizeChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain
qaChain!: RefineDocumentsChain | MapReduceDocumentsChain | StuffDocumentsChain

textSplitter!: TokenTextSplitter

constructor(
@@ -34,6 +37,8 @@ export class LLMService {
questionPrompt: SUMMARY_PROMPT,
refinePrompt: SUMMARY_REFINE_PROMPT,
})

this.qaChain = loadQAMapReduceChain(llm)
}

public async summarize(inputText: string) {
@@ -79,4 +84,47 @@

return resp.output_text
}

public async ask(inputText: string, question: string): Promise<string> {
const bar = new SingleBar(
{
clearOnComplete: false,
hideCursor: true,
format: `${chalk.magenta('Processing')} [{bar}] {percentage}% | ETA: {eta}s | {value}/{total}`,
},
Presets.shades_classic,
)

const document = new Document({
pageContent: inputText,
})
const docChunks = await this.textSplitter.splitDocuments([document])

this.logger?.info(`QA ${inputText.length} char (${docChunks.length} chunks) document...`)

bar.start(docChunks.length, 0)

let docCount = 0

const resp = await this.qaChain.invoke(
{
// eslint-disable-next-line camelcase
input_documents: docChunks,
question,
},
{
callbacks: [
{
handleLLMEnd: async () => {
bar.update(++docCount)
},
},
],
},
)

bar.stop()

return resp.text
}
}
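
The new ask() method splits the paper text into token-based chunks and runs LangChain's map-reduce question-answering chain over them, updating a progress bar per chunk. A standalone sketch of the same flow follows; the model choice, splitter settings, and import paths are assumptions (they vary by LangChain version) and are not taken from this commit:

import { loadQAMapReduceChain } from 'langchain/chains'
import { Document } from 'langchain/document'
import { TokenTextSplitter } from 'langchain/text_splitter'
import { ChatOpenAI } from '@langchain/openai'

async function askAboutText(text: string, question: string): Promise<string> {
  const llm = new ChatOpenAI({ model: 'gpt-4o-mini' }) // assumed model; the CLI injects its own provider
  const splitter = new TokenTextSplitter({ chunkSize: 10000, chunkOverlap: 500 }) // assumed sizes
  const chunks = await splitter.splitDocuments([new Document({ pageContent: text })])

  const chain = loadQAMapReduceChain(llm)
  // eslint-disable-next-line camelcase
  const resp = await chain.invoke({ input_documents: chunks, question })
  return resp.text
}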
2 changes: 2 additions & 0 deletions src/services/search/interfaces.ts
@@ -11,12 +11,14 @@ export interface IPaperEntity {
authors: string[]
matches?: ITextMatch[]
summary?: string
answer?: string
}

export interface ISearchOptions {
keywords: string
minItemCount: number
filterPattern?: string
summarize?: boolean
question?: string
onData?: (data: IPaperEntity) => Promise<any>
}
24 changes: 24 additions & 0 deletions src/services/search/paper-search.service.spec.ts
@@ -160,6 +160,30 @@ describe('PaperSearchService', () => {
})),
)
})

it('should ask question and store answer if question is provided', async () => {
const answer = 'transformer is better than RNN, duhhhh'
llmService.ask.mockResolvedValue(answer)

const entities = await service.search({
keywords: 'some keywords',
minItemCount: 10,
question: 'Is it better than RNN?',
})

expect(entities).toHaveLength(3)
expect(entities).toEqual(
page.papers.map(result => ({
title: result.title,
authors: result.authors.map(author => author.name),
description: result.description,
url: result.url,
citation: result.citation,
source: result.source,
answer,
})),
)
})
})

describe('exportToCSV', () => {
19 changes: 13 additions & 6 deletions src/services/search/paper-search.service.ts
@@ -26,14 +26,15 @@ export class PaperSearchService {
minItemCount,
filterPattern,
summarize,
question,
onData,
}: ISearchOptions): Promise<IPaperEntity[]> {
const papers: IPaperEntity[] = []

await this.googleScholar.iteratePapers(
{ keywords },
async paper => {
const entity = await this.processPaper(paper, filterPattern, summarize)
const entity = await this.processPaper(paper, { filterPattern, summarize, question })
if (!entity) return true

papers.push(entity)
@@ -78,12 +79,15 @@

private async processPaper(
paper: IPaperMetadata,
filterPattern?: string,
summarize?: boolean,
{
filterPattern,
summarize,
question,
}: Pick<ISearchOptions, 'filterPattern' | 'summarize' | 'question'>,
): Promise<IPaperEntity | undefined> {
const entity = this.toEntity(paper)

if (!filterPattern && !summarize) return entity
if (!filterPattern && !summarize && !question) return entity

try {
const textContent = await this.paperService.getTextContent(paper)
@@ -96,8 +100,11 @@
}

if (summarize) {
const summary = await this.llmService.summarize(textContent)
entity.summary = summary
entity.summary = await this.llmService.summarize(textContent)
}

if (question) {
entity.answer = await this.llmService.ask(textContent, question)
}

return entity
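
End-to-end, a caller can now pass a question through the search options and read the model's answer off each returned entity. A hedged usage sketch (the service instance is assumed to come from the typedi container after initSearchContainer has run; option and field names follow the ISearchOptions and IPaperEntity interfaces above):

import { Container } from 'typedi'

import { PaperSearchService } from './services/search/paper-search.service.js'

const searchService = Container.get(PaperSearchService)

const papers = await searchService.search({
  keywords: 'transformer protein folding',
  minItemCount: 5,
  question: 'What dataset was used for evaluation?',
  onData: async paper => {
    console.log(`${paper.title}: ${paper.answer}`)
  },
})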
