This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0ecbef8

Merge pull request #363 from urvashik/master

CNN/Dailymail Summarization

2 parents 3c5823f + f084b5b

File tree

1 file changed: +95 −24 lines changed

tensor2tensor/data_generators/cnn_dailymail.py

Lines changed: 95 additions & 24 deletions
@@ -21,6 +21,7 @@
 
 import os
 import tarfile
+import hashlib
 
 # Dependency imports
 
@@ -38,19 +39,31 @@
 
 _DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"
 
+# Note: using See et al. (2017) as reference for data generation
+# For more info, use the links below
+
+# Train/Dev/Test Splits for summarization data
+_TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
+_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
+_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
 
 # End-of-sentence marker.
 EOS = text_encoder.EOS_ID
 
+# Techniques for data prep from See et al. (2017)
+dm_single_close_quote = u'\u2019'  # unicode
+dm_double_close_quote = u'\u201d'
+END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"', dm_single_close_quote, dm_double_close_quote, u")"]  # acceptable ways to end a sentence
+
 
-def _maybe_download_corpora(tmp_dir):
+def _maybe_download_corpora(tmp_dir, is_training):
   """Download corpora if necessary and unzip them.
 
   Args:
     tmp_dir: directory containing dataset.
 
   Returns:
-    filepath of the downloaded corpus file.
+    list of all files generated and path to file containing train/dev/test split info.
   """
   cnn_filename = "cnn_stories.tgz"
   cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
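For context, END_TOKENS backs the fix_run_on_sents helper added further down in this diff: a nonempty line that does not already end in one of these tokens gets a period appended, so sentences from the raw .story files do not run together once they are joined with spaces. A minimal standalone sketch of that check, with made-up sample lines:

# Sketch of the end-token check; sample inputs are hypothetical.
END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"',
              u'\u2019', u'\u201d', u')']

def fix_run_on_sents(line):
  if u"@highlight" in line: return line   # highlight markers stay untouched
  if line == "": return line              # empty lines stay empty
  if line[-1] in END_TOKENS: return line  # already ends a sentence
  return line + u"."                      # otherwise, close the sentence

assert fix_run_on_sents(u"He said hello") == u"He said hello."
assert fix_run_on_sents(u"He said hello!") == u"He said hello!"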
@@ -66,29 +79,87 @@ def _maybe_download_corpora(tmp_dir):
       tmp_dir, dailymail_filename, _DAILYMAIL_STORIES_DRIVE_URL)
   with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
     dailymail_tar.extractall(tmp_dir)
-  return [cnn_finalpath, dailymail_finalpath]
-
-
-def story_generator(tmp_dir):
-  paths = _maybe_download_corpora(tmp_dir)
-  for path in paths:
-    for story_file in tf.gfile.Glob(path + "*"):
-      story = u""
-      for line in tf.gfile.Open(story_file, "rb"):
-        line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
-        story += line
-      yield story
 
+  cnn_files = tf.gfile.Glob(cnn_finalpath + "*")
+  dailymail_files = tf.gfile.Glob(dailymail_finalpath + "*")
+  all_files = cnn_files + dailymail_files
+
+  if is_training:
+    urls_path = generator_utils.maybe_download(tmp_dir, "all_train.txt", _TRAIN_URLS)
+  else:
+    urls_path = generator_utils.maybe_download(tmp_dir, "all_val.txt", _DEV_URLS)
+
+  return all_files, urls_path
+
+def example_splits(url_file, all_files):
+  def generate_hash(inp):
+    """Generate a sha1 hash to match the raw url to the filename extracted"""
+    h = hashlib.sha1()
+    h.update(inp)
+    return h.hexdigest()
+
+  all_files_map = {f.split("/")[-1]: f for f in all_files}
+
+  urls = []
+  for line in tf.gfile.Open(url_file):
+    urls.append(line.strip().encode('utf-8'))
+
+  filelist = []
+  for url in urls:
+    url_hash = generate_hash(url)
+    filename = url_hash + ".story"
+    if filename not in all_files_map:
+      tf.logging.info("Missing file: %s" % url)
+      continue
+    filelist.append(all_files_map[filename])
+
+  tf.logging.info("Found %d examples" % len(filelist))
+
+  return filelist
+
+def example_generator(tmp_dir, is_training, sum_token):
+  def fix_run_on_sents(line):
+    if u"@highlight" in line: return line
+    if line == "": return line
+    if line[-1] in END_TOKENS: return line
+    return line + u"."
+
+  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
+  filelist = example_splits(urls_path, all_files)
+  story_summary_split_token = u" <summary> " if sum_token else " "
+
+  for story_file in filelist:
+    story = []
+    summary = []
+    reading_highlights = False
+    for line in tf.gfile.Open(story_file, "rb"):
+      line = unicode(line.strip(), "utf-8") if six.PY2 else line.strip().decode("utf-8")
+      line = fix_run_on_sents(line)
+      if line == "":
+        continue
+      elif line.startswith(u"@highlight"):
+        if len(story) == 0: break  # No article text
+        reading_highlights = True
+      elif reading_highlights:
+        summary.append(line)
+      else:
+        story.append(line)
+
+    if len(story) == 0 or len(summary) == 0:
+      continue
+
+    yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
 def _story_summary_split(story):
-  end_pos = story.find("\n\n")  # Upto first empty line.
-  assert end_pos != -1
-  return story[:end_pos], story[end_pos:].strip()
+  split_str = u" <summary> "
+  split_str_len = len(split_str)
+  split_pos = story.find(split_str)
+  return story[:split_pos], story[split_pos + split_str_len:]  # story, summary
 
 
 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
-  """Summarize CNN and Daily Mail articles to their first paragraph."""
+  """Summarize CNN and Daily Mail articles to their summary highlights."""
 
   @property
   def is_character_level(self):
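The url_lists files identify each example by its original article URL, while the extracted archives name each story file after the sha1 hex digest of that URL; example_splits reconnects the two by hashing every URL and appending ".story", as generate_hash's docstring describes. A minimal sketch of that mapping, using a made-up URL:

import hashlib

def url_to_story_filename(url):
  # sha1 over the raw bytes of the URL, hex-encoded, as in generate_hash above.
  return hashlib.sha1(url.encode("utf-8")).hexdigest() + ".story"

# Hypothetical URL, for illustration only; the result would be looked up
# in all_files_map to find the extracted .story file.
print(url_to_story_filename("http://www.cnn.com/some-article"))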
@@ -124,14 +195,14 @@ def targeted_vocab_size(self):
 
   @property
   def use_train_shards_for_dev(self):
-    return True
+    return False
 
-  def generator(self, data_dir, tmp_dir, _):
+  def generator(self, data_dir, tmp_dir, is_training):
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
-        story_generator(tmp_dir))
-    for story in story_generator(tmp_dir):
-      summary, rest = _story_summary_split(story)
+        example_generator(tmp_dir, is_training, sum_token=False))
+    for example in example_generator(tmp_dir, is_training, sum_token=True):
+      story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]
-      encoded_story = encoder.encode(rest) + [EOS]
+      encoded_story = encoder.encode(story) + [EOS]
       yield {"inputs": encoded_story, "targets": encoded_summary}
