Skip to content

Commit 07731f1

Browse files
author
Vladimir Dobrovolskii
committed
added prediction script
1 parent 81e66c2 commit 07731f1

File tree

3 files changed

+120
-3
lines changed

3 files changed

+120
-3
lines changed

README.md

+39-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
This is a repository with the code to reproduce the experiments described in the paper of the same name, which was accepted to EMNLP 2021. The paper is available [here](https://aclanthology.org/2021.emnlp-main.605/).
44

55
### Table of contents
6-
1. [Preparation](#preparation)
7-
2. [Training](#training)
8-
3. [Evaluation](#evaluation)
6+
- [Word-Level Coreference Resolution](#word-level-coreference-resolution)
7+
- [Table of contents](#table-of-contents)
8+
- [Preparation](#preparation)
9+
- [Training](#training)
10+
- [Evaluation](#evaluation)
11+
- [Prediction](#prediction)
12+
- [Citation](#citation)
913

1014
### Preparation
1115

@@ -67,6 +71,38 @@ Make sure that you have successfully completed all steps of the [Preparation](#p
6771

6872
python calculate_conll.py roberta test 20
6973

74+
### Prediction
75+
76+
To predict coreference relations on an arbitrary text, you will need to prepare the data in the jsonlines format (one json-formatted document per line).
77+
The following fields are required:
78+
79+
{
80+
"document_id": "tc_mydoc_001",
81+
"cased_words": ["Hi", "!", "Bye", "."],
82+
"sent_id": [0, 0, 1, 1]
83+
}
84+
85+
You can optionally provide the speaker data:
86+
87+
{
88+
"speaker": ["Tom", "Tom", "#2", "#2"]
89+
}
90+
91+
`document_id` can be any string that starts with a two-letter genre identifier. The genres recognized are the following:
92+
* bc: broadcast conversation
93+
* bn: broadcast news
94+
* mz: magazine genre (Sinorama magazine)
95+
* nw: newswire genre
96+
* pt: pivot text (The Bible)
97+
* tc: telephone conversation (CallHome corpus)
98+
* wb: web data
99+
100+
Then run:
101+
102+
python predict.py roberta input.jsonlines output.jsonlines
103+
104+
This will utilize the latest weights available in the data directory for the chosen configuration. To load other weights, use the `--weights` argument.
105+
70106
### Citation
71107
@inproceedings{dobrovolskii-2021-word,
72108
title = "Word-Level Coreference Resolution",

predict.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import argparse
2+
3+
import jsonlines
4+
import torch
5+
from tqdm import tqdm
6+
7+
from coref import CorefModel
8+
from coref.tokenizer_customization import *
9+
10+
11+
def build_doc(doc: dict, model: CorefModel) -> dict:
    """Prepare a raw jsonlines document in place for prediction.

    Tokenizes every cased word into subword tokens (honoring per-tokenizer
    overrides in TOKENIZER_MAPS and filters in TOKENIZER_FILTERS), records
    the word-to-subword span mapping, and fills in the auxiliary fields the
    model expects: a placeholder speaker list when none is given, plus empty
    head2span / cluster containers.  Returns the same dict for convenience.
    """
    bert_name = model.config.bert_model
    keep_subword = TOKENIZER_FILTERS.get(bert_name, lambda _: True)
    overrides = TOKENIZER_MAPS.get(bert_name, {})

    spans = []    # (start, end) subword index pair for each word
    pieces = []   # flat list of subword tokens over the whole document
    owners = []   # owning word index for each subword token
    for idx, word in enumerate(doc["cased_words"]):
        raw = (overrides[word] if word in overrides
               else model.tokenizer.tokenize(word))
        toks = [t for t in raw if keep_subword(t)]
        spans.append((len(pieces), len(pieces) + len(toks)))
        pieces.extend(toks)
        owners.extend([idx] * len(toks))

    doc["word2subword"] = spans
    doc["subwords"] = pieces
    doc["word_id"] = owners

    doc["head2span"] = []
    if "speaker" not in doc:
        # Default placeholder speaker for every word.
        doc["speaker"] = ["_"] * len(doc["cased_words"])
    doc["word_clusters"] = []
    doc["span_clusters"] = []

    return doc
38+
39+
40+
if __name__ == "__main__":
    # Command-line entry point: predict coreference clusters for every
    # jsonlines document in input_file and write the augmented documents
    # to output_file.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("experiment")
    argparser.add_argument("input_file")
    argparser.add_argument("output_file")
    argparser.add_argument("--config-file", default="config.toml")
    argparser.add_argument("--batch-size", type=int,
                           help="Adjust to override the config value if you're"
                                " experiencing out-of-memory issues")
    argparser.add_argument("--weights",
                           help="Path to file with weights to load."
                                " If not supplied, the latest"
                                " weights of the experiment will be loaded;"
                                " if there aren't any, an error is raised.")
    args = argparser.parse_args()

    model = CorefModel(args.config_file, args.experiment)

    # Override the scoring batch size only when the flag was actually
    # supplied ("is not None" rather than truthiness, so an explicit
    # --batch-size 0 is not silently ignored).
    if args.batch_size is not None:
        model.config.a_scoring_batch_size = args.batch_size

    model.load_weights(path=args.weights, map_location="cpu",
                       ignore={"bert_optimizer", "general_optimizer",
                               "bert_scheduler", "general_scheduler"})
    model.training = False  # disable training-mode behavior for inference

    with jsonlines.open(args.input_file, mode="r") as input_data:
        docs = [build_doc(doc, model) for doc in input_data]

    # No gradients are needed for prediction.
    with torch.no_grad():
        for doc in tqdm(docs, unit="docs"):
            result = model.run(doc)
            doc["span_clusters"] = result.span_clusters
            doc["word_clusters"] = result.word_clusters

            # Drop the preprocessing-only fields so the output contains
            # just the original document plus the predicted clusters.
            for key in ("word2subword", "subwords", "word_id", "head2span"):
                del doc[key]

    with jsonlines.open(args.output_file, mode="w") as output_data:
        output_data.write_all(docs)

sample_input.jsonlines

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{"document_id": "tc_sample_input_001", "cased_words": ["Hi", ",", "my", "name", "is", "Tom", ".", "I", "am", "five", "."], "sent_id": [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], "speaker": ["Tom", "Tom", "Tom", "Tom", "Tom", "Tom", "Tom", "Tom", "Tom", "Tom", "Tom"]}
2+
{"document_id": "pt_sample_input_001", "cased_words": ["Because", "Joseph", "her", "husband", "was", "faithful", "to", "the", "law,", "and", "yet", "did", "not", "want", "to", "expose", "her", "to", "public", "disgrace,", "he", "had", "in", "mind", "to", "divorce", "her", "quietly", "."], "sent_id": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

0 commit comments

Comments
 (0)