
import os
import tarfile
+ import hashlib

# Dependency imports


_DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"

+ # Note: the data generation follows See et al. (2017).
+ # For more info, see the links below.
+
+ # Train/Dev/Test splits for the summarization data.
+ _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
+ _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
+ _TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

+ # Techniques for data prep from See et al. (2017).
+ dm_single_close_quote = u'\u2019'  # unicode
+ dm_double_close_quote = u'\u201d'
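+ # \u2019 and \u201d are the Unicode right single and right double quotation
+ # marks; they count as valid sentence endings below.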
+ # Acceptable ways to end a sentence.
+ END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"',
+               dm_single_close_quote, dm_double_close_quote, u")"]
+

- def _maybe_download_corpora(tmp_dir):
+ def _maybe_download_corpora(tmp_dir, is_training):
  """Download corpora if necessary and unzip them.

  Args:
    tmp_dir: directory containing dataset.
+     is_training: whether we are generating the training split or the dev split.

  Returns:
-     filepath of the downloaded corpus file.
+     List of all story files found, and the path to the file containing the
+       train/dev/test split info.
  """
  cnn_filename = "cnn_stories.tgz"
  cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
@@ -66,29 +79,87 @@ def _maybe_download_corpora(tmp_dir):
      tmp_dir, dailymail_filename, _DAILYMAIL_STORIES_DRIVE_URL)
  with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
    dailymail_tar.extractall(tmp_dir)
-   return [cnn_finalpath, dailymail_finalpath]
-
-
- def story_generator(tmp_dir):
-   paths = _maybe_download_corpora(tmp_dir)
-   for path in paths:
-     for story_file in tf.gfile.Glob(path + "*"):
-       story = u""
-       for line in tf.gfile.Open(story_file, "rb"):
-         line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
-         story += line
-       yield story
+
+   cnn_files = tf.gfile.Glob(cnn_finalpath + "*")
+   dailymail_files = tf.gfile.Glob(dailymail_finalpath + "*")
+   all_files = cnn_files + dailymail_files
+
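+   # The URL lists from See et al.'s cnn-dailymail repo define which stories
+   # belong to the train and dev splits.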
+   if is_training:
+     urls_path = generator_utils.maybe_download(
+         tmp_dir, "all_train.txt", _TRAIN_URLS)
+   else:
+     urls_path = generator_utils.maybe_download(
+         tmp_dir, "all_val.txt", _DEV_URLS)
+
+   return all_files, urls_path
+
94
+ def example_splits (url_file , all_files ):
95
+ def generate_hash (inp ):
96
+ """Generate a sha1 hash to match the raw url to the filename extracted"""
97
+ h = hashlib .sha1 ()
98
+ h .update (inp )
99
+ return h .hexdigest ()
100
+
101
+ all_files_map = {f .split ("/" )[- 1 ]:f for f in all_files }
102
+
103
+ urls = []
104
+ for line in tf .gfile .Open (url_file ):
105
+ urls .append (line .strip ().encode ('utf-8' ))
106
+
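+   # Each story file extracted from the archives is named by the SHA1 hex
+   # digest of its source URL, so hashing each URL recovers its filename.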
+   filelist = []
+   for url in urls:
+     url_hash = generate_hash(url)
+     filename = url_hash + ".story"
+     if filename not in all_files_map:
+       tf.logging.info("Missing file: %s" % url)
+       continue
+     filelist.append(all_files_map[filename])
+
+   tf.logging.info("Found %d examples" % len(filelist))
+
+   return filelist
+
+ def example_generator(tmp_dir, is_training, sum_token):
+   def fix_run_on_sents(line):
+     if u"@highlight" in line: return line
+     if line == "": return line
+     if line[-1] in END_TOKENS: return line
+     return line + u"."
+
+   all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
+   filelist = example_splits(urls_path, all_files)
+   story_summary_split_token = u" <summary> " if sum_token else " "
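+   # With sum_token=True the article and its highlights are joined by a
+   # " <summary> " marker, which _story_summary_split() later uses to
+   # separate them again.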
+
+   for story_file in filelist:
+     story = []
+     summary = []
+     reading_highlights = False
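+     # Lines before the first "@highlight" marker are article text; each line
+     # after an "@highlight" is collected as part of the summary.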
+     for line in tf.gfile.Open(story_file, "rb"):
+       line = unicode(line.strip(), "utf-8") if six.PY2 else line.strip().decode("utf-8")
+       line = fix_run_on_sents(line)
+       if line == "":
+         continue
+       elif line.startswith(u"@highlight"):
+         if len(story) == 0: break  # No article text.
+         reading_highlights = True
+       elif reading_highlights:
+         summary.append(line)
+       else:
+         story.append(line)
+
+     if len(story) == 0 or len(summary) == 0:
+       continue
+
+     yield " ".join(story) + story_summary_split_token + " ".join(summary)

def _story_summary_split(story):
-   end_pos = story.find("\n\n")  # Up to the first empty line.
-   assert end_pos != -1
-   return story[:end_pos], story[end_pos:].strip()
+   split_str = u" <summary> "
+   split_str_len = len(split_str)
+   split_pos = story.find(split_str)
+   return story[:split_pos], story[split_pos + split_str_len:]  # story, summary


@registry.register_problem
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
-   """Summarize CNN and Daily Mail articles to their first paragraph."""
+   """Summarize CNN and Daily Mail articles to their summary highlights."""

  @property
  def is_character_level(self):
@@ -124,14 +195,14 @@ def targeted_vocab_size(self):

  @property
  def use_train_shards_for_dev(self):
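+     # Dev data now comes from its own URL split (all_val.txt) rather than
+     # reusing the training shards.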
-     return True
+     return False

-   def generator(self, data_dir, tmp_dir, _):
+   def generator(self, data_dir, tmp_dir, is_training):
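+     # The vocab is built from text generated without the " <summary> " marker
+     # (sum_token=False); examples are then generated with it (sum_token=True)
+     # only so that _story_summary_split() can split story from summary below.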
    encoder = generator_utils.get_or_generate_vocab_inner(
        data_dir, self.vocab_file, self.targeted_vocab_size,
-         story_generator(tmp_dir))
-     for story in story_generator(tmp_dir):
-       summary, rest = _story_summary_split(story)
+         example_generator(tmp_dir, is_training, sum_token=False))
+     for example in example_generator(tmp_dir, is_training, sum_token=True):
+       story, summary = _story_summary_split(example)
      encoded_summary = encoder.encode(summary) + [EOS]
-       encoded_story = encoder.encode(rest) + [EOS]
+       encoded_story = encoder.encode(story) + [EOS]
      yield {"inputs": encoded_story, "targets": encoded_summary}