Add new document-to-passage splitting mechanism
Thejas Venkatesh committed Apr 3, 2023
1 parent 4e56f4d commit cf44222
Showing 2 changed files with 161 additions and 4 deletions.
16 changes: 12 additions & 4 deletions utility/preprocess/docs2passages.py
@@ -29,10 +29,18 @@ def process_page(inp):
     else:
         words = tokenizer.tokenize(content)
 
-    words_ = (words + words) if len(words) > nwords else words
-    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]
-
-    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))
+    n_passages = (len(words) + nwords - 1) // nwords
+    if n_passages > 1:
+        last_2_passage_length = len(words) - nwords * (n_passages - 2)
+        passage_lengths = [0] + [nwords] * (n_passages - 2) + [last_2_passage_length // 2] + [last_2_passage_length - last_2_passage_length // 2]
+        assert sum(passage_lengths) == len(words)
+    elif n_passages == 1:
+        passage_lengths = [0, len(words)]
+    else:
+        passage_lengths = [0]
+    print(n_passages, passage_lengths)
+    assert len(passage_lengths) == n_passages + 1
+    passages = [words[passage_lengths[idx-1]:passage_lengths[idx]] for idx in range(1, len(passage_lengths))]
 
     if tokenizer is None:
         passages = [' '.join(psg) for psg in passages]
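
A note on the new splitting arithmetic: passage_lengths stores per-passage lengths (with a leading 0), not slice offsets, so the final comprehension lines up only once those lengths are accumulated into cumulative boundaries; that step isn't visible in the hunk as rendered. Below is a minimal standalone sketch of the scheme under that assumption; split_words is a hypothetical helper for illustration, not code from this commit.

from itertools import accumulate

def split_words(words, nwords):
    # Hypothetical sketch of the even-split scheme above, not the committed code.
    n_passages = (len(words) + nwords - 1) // nwords  # ceiling division

    if n_passages > 1:
        # All passages are exactly nwords long except the last two, which split
        # the leftover words as evenly as possible so neither ends up tiny.
        last_2 = len(words) - nwords * (n_passages - 2)
        passage_lengths = [0] + [nwords] * (n_passages - 2) + [last_2 // 2, last_2 - last_2 // 2]
    elif n_passages == 1:
        passage_lengths = [0, len(words)]
    else:
        passage_lengths = [0]  # empty document: no passages

    # Turn lengths into cumulative slice boundaries, e.g. [0, 100, 175, 250]
    # for a 250-word document with nwords=100.
    offsets = list(accumulate(passage_lengths))
    return [words[offsets[idx - 1]:offsets[idx]] for idx in range(1, len(offsets))]

For a 250-word document with nwords=100 this yields passages of 100, 75, and 75 words, where the old wrap-around mechanism would have produced three passages of exactly 100 words, the last one wrapping past the end of the document.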
149 changes: 149 additions & 0 deletions utility/preprocess/docs2passages_dpr.py
@@ -0,0 +1,149 @@
"""
    Divide a document collection into N-word/token passage spans (with wrap-around for last passage).
"""

import os
import math
import ujson
import random

from multiprocessing import Pool
from argparse import ArgumentParser
from colbert.utils.utils import print_message

Format1 = 'docid,text' # MS MARCO Passages
Format2 = 'docid,text,title' # DPR Wikipedia
Format3 = 'docid,url,title,text' # MS MARCO Documents


def process_page(inp):
"""
Wraps around if we split: make sure last passage isn't too short.
This is meant to be similar to the DPR preprocessing.
"""

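    # Each work item pairs the shared splitting settings (nwords, overlap,
    # tokenizer) with one raw document record from the input collection.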
    (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp

    if tokenizer is None:
        words = content.split()
    else:
        words = tokenizer.tokenize(content)

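    # If the document needs splitting, duplicate the word list so the final
    # window can wrap around to the start instead of coming up short.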
    words_ = (words + words) if len(words) > nwords else words
    passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)]

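    # Invariant: every passage is exactly nwords long, unless the whole
    # document fits into a single shorter passage.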
    assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words))

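    # Re-join tokens into text; for WordPiece, strip the '##' continuation markers.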
    if tokenizer is None:
        passages = [' '.join(psg) for psg in passages]
    else:
        passages = [' '.join(psg).replace(' ##', '') for psg in passages]

    if title_idx % 100000 == 0:
        print("#> ", title_idx, '\t\t\t', title)

        for p in passages:
            print("$$$ ", '\t\t', p)
        print()

        print()
        print()
        print()

    return (docid, title, url, passages)


def main(args):
    random.seed(12345)
    print_message("#> Starting...")

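    # The output filename encodes the configuration: 'w' for whitespace words or
    # 't' for WordPiece tokens, followed by the passage size and the overlap.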
    letter = 'w' if not args.use_wordpiece else 't'
    output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}'
    assert not os.path.exists(output_path)

    RawCollection = []
    Collection = []

    NumIllFormattedLines = 0

    with open(args.input) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (100*1000) == 0:
                print(line_idx, end=' ')

            title, url = None, None

            try:
                line = line.strip().split('\t')

                if args.format == Format1:
                    docid, doc = line
                elif args.format == Format2:
                    docid, doc, title = line
                elif args.format == Format3:
                    docid, url, title, doc = line

                RawCollection.append((line_idx, docid, title, url, doc))
            except:
                NumIllFormattedLines += 1

                if NumIllFormattedLines % 1000 == 0:
                    print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n')

    print()
    print_message("# of documents is", len(RawCollection), '\n')

    p = Pool(args.nthreads)

    print_message("#> Starting parallel processing...")

    tokenizer = None
    if args.use_wordpiece:
        from transformers import BertTokenizerFast
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

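    # Broadcast the same (nwords, overlap, tokenizer) settings to every document
    # so each worker can unpack them inside process_page.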
    process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection)
    Collection = p.map(process_page, zip(process_page_params, RawCollection))

    print_message(f"#> Writing to {output_path} ...")
    with open(output_path, 'w') as f:
        line_idx = 1

        if args.format == Format1:
            f.write('\t'.join(['id', 'text']) + '\n')
        elif args.format == Format2:
            f.write('\t'.join(['id', 'text', 'title']) + '\n')
        elif args.format == Format3:
            f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n')

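        # Emit one TSV row per passage, numbering passages sequentially across
        # the whole collection (the source docid survives only in Format3 rows).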
        for docid, title, url, passages in Collection:
            for passage in passages:
                if args.format == Format1:
                    f.write('\t'.join([str(line_idx), passage]) + '\n')
                elif args.format == Format2:
                    f.write('\t'.join([str(line_idx), passage, title]) + '\n')
                elif args.format == Format3:
                    f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n')

                line_idx += 1


if __name__ == "__main__":
    parser = ArgumentParser(description="docs2passages.")

    # Input Arguments.
    parser.add_argument('--input', dest='input', required=True)
    parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3])

    # Output Arguments.
    parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true')
    parser.add_argument('--nwords', dest='nwords', default=100, type=int)
    parser.add_argument('--overlap', dest='overlap', default=0, type=int)

    # Other Arguments.
    parser.add_argument('--nthreads', dest='nthreads', default=28, type=int)

    args = parser.parse_args()
    assert args.nwords in range(50, 500)

    main(args)
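
For reference, a typical invocation of the new script might look like the following, where collection.tsv is a placeholder input path and the flags are the ones registered above:

python utility/preprocess/docs2passages_dpr.py --input collection.tsv --format docid,text,title --nwords 100 --overlap 0

With the default whitespace tokenizer this writes a header row plus one numbered passage per TSV line to collection.tsv.w100_0; passing --use-wordpiece switches to BERT WordPiece tokenization (and a 't' in the output filename).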
