Skip to content

Commit abeed54

Browse files
committed
Clean-up internal data api for faa_aligner
1 parent bde002e commit abeed54

File tree

6 files changed

+120
-76
lines changed

6 files changed

+120
-76
lines changed

amrlib/alignments/faa_aligner/faa_aligner.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@
1717
class FAA_Aligner(object):
1818
def __init__(self, **kwargs):
1919
self.model_dir = kwargs.get('model_dir', os.path.join(data_dir, 'model_aligner_faa'))
20-
self.working_dir = kwargs.get('working_dir', os.path.join(data_dir, 'working_faa_aligner'))
2120
self.model_tar_fn = kwargs.get('model_tar_fn', os.path.join(this_dir, 'model_aligner_faa.tar.gz'))
22-
os.makedirs(self.working_dir, exist_ok=True)
2321
self.setup_model_dir()
2422
self.aligner = TrainedAligner(self.model_dir, **kwargs)
2523
self.aligner.check_for_binaries() # Will raise FileNotFoundError if binaries can't be found
2624

25+
# Input space_tok_sents is a list of space tokenized strings
26+
# graph_strings is a list of amr graph strings, the same size.
27+
def align_sents(self, space_tok_sents, graph_strings):
28+
assert len(space_tok_sents) == len(graph_strings)
29+
graph_strings = [to_graph_line(g) for g in graph_strings]
30+
data = preprocess_infer(space_tok_sents, graph_strings)
31+
data.model_out_lines = self.aligner.align(data.eng_preproc_lines, data.amr_preproc_lines)
32+
amr_surface_aligns, alignment_strings = postprocess(data)
33+
return amr_surface_aligns, alignment_strings
34+
35+
2736
# check the model directory, if it doesn't have the metadata file try to create
2837
# the directory from the tar.gz file
2938
def setup_model_dir(self):
@@ -43,13 +52,6 @@ def setup_model_dir(self):
4352
logger.critical('No model in model_dir and no local version available to extract')
4453
return False
4554

46-
def align_sents(self, sents, gstrings):
47-
gstrings = [to_graph_line(g) for g in gstrings]
48-
eng_td_lines, amr_td_lines = preprocess_infer(self.working_dir, sents, gstrings)
49-
fa_out_lines = self.aligner.align(eng_td_lines, amr_td_lines)
50-
amr_surface_aligns, alignment_strings = postprocess(self.working_dir, fa_out_lines, sents, gstrings)
51-
return amr_surface_aligns, alignment_strings
52-
5355

5456
# Code adapted from from https://github.com/clab/fast_align/blob/master/src/force_align.py
5557
class TrainedAligner:

amrlib/alignments/faa_aligner/postprocess.py

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,48 +10,28 @@
1010

1111

1212
# if model_out_lines is None, read from the file
13-
def postprocess(wk_dir, model_out_lines=None, eng_lines=None, amr_lines=None, **kwargs):
14-
# Input filenames
15-
eng_fn = os.path.join(wk_dir, kwargs.get('eng_fn', 'sents.txt'))
16-
amr_fn = os.path.join(wk_dir, kwargs.get('amr_fn', 'gstrings.txt'))
17-
eng_tok_pos_fn = os.path.join(wk_dir, kwargs.get('eng_tok_pos_fn', 'eng_tok_origpos.txt'))
18-
amr_tuple_fn = os.path.join(wk_dir, kwargs.get('amr_tuple', 'amr_tuple.txt'))
19-
model_out_fn = os.path.join(wk_dir, kwargs.get('model_out_fn', 'model_out.txt'))
13+
def postprocess(data, **kwargs):
2014
# Error log
21-
align_to_str_fn = os.path.join(wk_dir, kwargs.get('align_to_str_fn', 'align_to_str.err'))
22-
23-
# Read the input files and get the number of lines, which must be the same
24-
if eng_lines is None or amr_lines is None:
25-
with open(eng_fn) as f:
26-
eng_lines = [l.strip() for l in f]
27-
with open(amr_fn) as f:
28-
amr_lines = [l.strip() for l in f]
29-
assert len(eng_lines) == len(amr_lines)
30-
lines_number = len(eng_lines)
15+
log_dir = kwargs.get('log_dir', 'logs')
16+
os.makedirs(log_dir, exist_ok=True)
17+
align_to_str_fn = os.path.join(log_dir, kwargs.get('postprocess_log_fn', 'faa_postprocess.log'))
3118

3219
# Read the output of the aligner or use the supplied input above
3320
# fast_align outputs with a dash but the code from the isi aligner is setup for spaces
34-
if model_out_lines is None:
35-
with open(model_out_fn) as f:
36-
model_out_lines = f.readlines()
37-
# fast_align outputs with dashes, giza does this without
38-
giza_align_lines = [l.strip().replace('-', ' ') for l in model_out_lines]
21+
giza_align_lines = [l.strip().replace('-', ' ') for l in data.model_out_lines]
3922
isi_align_lines = giza2isi(giza_align_lines)
40-
align_real_lines = swap(isi_align_lines)[:lines_number] # rm data added for training, past original sentences
23+
num_lines = len(data.amr_lines)
24+
align_real_lines = swap(isi_align_lines)[:num_lines] # rm data added for training, past original sentences
4125

42-
# Load the original sentence tokenization positions (created in pre-process)
43-
with open(eng_tok_pos_fn) as f:
44-
eng_tok_origpos_lines = [l.strip() for l in f]
45-
align_origpos_lines = map_ibmpos_to_origpos_amr_as_f(eng_tok_origpos_lines, align_real_lines)
26+
# Align the original position lines
27+
align_origpos_lines = map_ibmpos_to_origpos_amr_as_f(data.eng_tok_origpos_lines, align_real_lines)
4628

47-
# Load the amr tuples from the pre-process steps and add the alignments
48-
with open(amr_tuple_fn) as f:
49-
amr_tuple_lines = [l.strip() for l in f]
50-
aligned_tuple_lines = get_aligned_tuple_amr_as_f_add_align(amr_tuple_lines, align_origpos_lines)
29+
# Get the aligned tuples
30+
aligned_tuple_lines = get_aligned_tuple_amr_as_f_add_align(data.amr_tuple_lines, align_origpos_lines)
5131

5232
# Create amr graphs with surface alignments
53-
amr_surface_aligns = feat2tree.align(amr_lines, aligned_tuple_lines, log_fn=align_to_str_fn)
54-
assert len(amr_surface_aligns) == len(eng_lines)
33+
amr_surface_aligns = feat2tree.align(data.amr_lines, aligned_tuple_lines, log_fn=align_to_str_fn)
34+
assert len(amr_surface_aligns) == len(data.amr_lines)
5535

5636
# Get the final alignment string from the surface alignments
5737
ga = GetAlignments.from_amr_strings(amr_surface_aligns)

amrlib/alignments/faa_aligner/preprocess.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,54 @@
44
from .process_utils import stem_4_letters_word, stem_4_letters_line, stem_4_letters_string
55
from .process_utils import filter_eng_by_stopwords, get_lineartok_with_rel
66
from .process_utils import get_id_mapping_uniq
7-
7+
from .proc_data import ProcData
88

99
# Set the default data dir for misc files
1010
default_res_dir = os.path.dirname(os.path.realpath(__file__))
1111
default_res_dir = os.path.realpath(os.path.join(default_res_dir, 'resources'))
1212

1313

1414
# Preprocess for inference
15-
def preprocess_infer(wk_dir, eng_lines, amr_lines, **kwargs):
15+
def preprocess_infer(eng_lines, amr_lines, **kwargs):
1616
assert len(eng_lines) == len(amr_lines)
1717
# Resource filenames
1818
res_dir = kwargs.get('res_dir', default_res_dir)
1919
eng_sw_fn = kwargs.get('eng_sw_fn', os.path.join(res_dir, 'eng_stopwords.txt'))
2020
amr_sw_fn = kwargs.get('amr_sw_fn', os.path.join(res_dir, 'amr_stopwords.txt'))
21-
# Output filenames
22-
eng_tok_pos_fn = os.path.join(wk_dir, kwargs.get('eng_tok_pos_fn', 'eng_tok_origpos.txt'))
23-
amr_tuple_fn = os.path.join(wk_dir, kwargs.get('amr_tuple_fn', 'amr_tuple.txt'))
2421

2522
# Filter out stopwords from sentences
2623
eng_tok_filtered_lines, eng_tok_origpos_lines = filter_eng_by_stopwords(eng_lines, eng_sw_fn)
27-
28-
# Save for post-processing
29-
with open(eng_tok_pos_fn, 'w') as f:
30-
for i, l in enumerate(eng_tok_origpos_lines):
31-
assert l.strip(), '!!! ERROR Empty line# %d. This will cause issues and must be fixed !!!' % i
32-
f.write(l + '\n')
24+
for i, line in enumerate(eng_tok_origpos_lines):
25+
if not line.strip():
26+
raise ValueError('!!! ERROR Empty line# %d. This will cause issues and must be fixed !!!' % i)
3327

3428
# Stem sentence tokens
35-
eng_tok_stemmed_lines = [stem_4_letters_line(l) for l in eng_tok_filtered_lines]
29+
eng_preproc_lines = [stem_4_letters_line(l) for l in eng_tok_filtered_lines]
3630

3731
# Process the AMR data / remove stopwords
3832
amr_linear_lines, amr_tuple_lines = get_lineartok_with_rel(amr_lines, amr_sw_fn)
3933

40-
# Save for post-processing
41-
with open(amr_tuple_fn, 'w') as f:
42-
for l in amr_tuple_lines:
43-
f.write(l + '\n')
44-
4534
# Stem the AMR lines
46-
amr_linear_stemmed_lines = []
35+
amr_preproc_lines = []
4736
for line in amr_linear_lines:
4837
new_tokens = []
4938
for token in line.split():
5039
token = re.sub(r'\-[0-9]{2,3}$', '', token)
5140
token = token.replace('"', '')
5241
token = stem_4_letters_word(token).strip()
5342
new_tokens.append(token)
54-
amr_linear_stemmed_lines.append(' '.join(new_tokens))
43+
amr_preproc_lines.append(' '.join(new_tokens))
5544

56-
return eng_tok_stemmed_lines, amr_linear_stemmed_lines
45+
# Gather the data
46+
assert len(eng_preproc_lines) == len(amr_preproc_lines)
47+
data = ProcData(eng_lines, amr_lines, eng_tok_origpos_lines, amr_tuple_lines,
48+
eng_preproc_lines, amr_preproc_lines,)
49+
return data
5750

5851

5952
# Preprocess the training data. This is the similar to inference but add a lot of
6053
# extra translation lines from resource files, etc..
61-
def preprocess_train(wk_dir, eng_lines, amr_lines, **kwargs):
54+
def preprocess_train(eng_lines, amr_lines, **kwargs):
6255
repeat_td = kwargs.get('repeat_td', 10) # 10X is original value from isi aligner
6356
# Resource filenames
6457
res_dir = kwargs.get('res_dir', default_res_dir)
@@ -67,10 +60,12 @@ def preprocess_train(wk_dir, eng_lines, amr_lines, **kwargs):
6760
amr_id_map_fn = kwargs.get('amr_id_map_fn', os.path.join(res_dir, 'amr_id_map.txt'))
6861

6962
# Run the inference process which creates the basic translation data
70-
eng_tok_stemmed_lines, amr_linear_stemmed_lines = preprocess_infer(wk_dir, eng_lines, amr_lines, **kwargs)
63+
data = preprocess_infer(eng_lines, amr_lines, **kwargs)
64+
eng_preproc_lines = data.eng_preproc_lines
65+
amr_preproc_lines = data.amr_preproc_lines
7166

7267
# Get tokens common between the two datasets (obvious translations
73-
common_tok_lines = get_id_mapping_uniq(eng_tok_stemmed_lines, amr_linear_stemmed_lines)
68+
common_tok_lines = get_id_mapping_uniq(eng_preproc_lines, amr_preproc_lines)
7469
eng_td_lines = common_tok_lines[:] # copy
7570

7671
# Append the second field in prep-roles.id.txt
@@ -101,8 +96,8 @@ def preprocess_train(wk_dir, eng_lines, amr_lines, **kwargs):
10196

10297
# Create the final training data using the original sentences
10398
# and 10X copies of the additional data (other translations)
104-
eng_td_lines = eng_tok_stemmed_lines + [l for _ in range(repeat_td) for l in eng_td_lines]
105-
amr_td_lines = amr_linear_stemmed_lines + [l for _ in range(repeat_td) for l in amr_td_lines]
106-
assert len(eng_td_lines) == len(amr_td_lines)
99+
data.eng_preproc_lines += [l for _ in range(repeat_td) for l in eng_td_lines]
100+
data.amr_preproc_lines += [l for _ in range(repeat_td) for l in amr_td_lines]
101+
assert len(data.eng_preproc_lines) == len(data.amr_preproc_lines)
107102

108-
return eng_td_lines, amr_td_lines
103+
return data
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import os


# Simple container for holding process data.
# This is a little complicated because preprocess creates data that is used by the
# model as input, and other data that is used later during postprocessing. Keep
# track of all of it here.
# Saving and loading exist to facilitate training scripts. During inference the
# data is held internally (in memory) only.
class ProcData(object):
    def __init__(self, eng_lines=None, amr_lines=None,
                 eng_tok_origpos_lines=None, amr_tuple_lines=None,
                 eng_preproc_lines=None, amr_preproc_lines=None,
                 model_out_lines=None):
        self.eng_lines             = eng_lines              # original space-tokenized sentences
        self.amr_lines             = amr_lines              # original amr graph strings (one line each)
        self.eng_tok_origpos_lines = eng_tok_origpos_lines  # token -> original-position data from preprocess
        self.amr_tuple_lines       = amr_tuple_lines        # amr tuples produced by preprocess
        self.eng_preproc_lines     = eng_preproc_lines      # model input, english side
        self.amr_preproc_lines     = amr_preproc_lines      # model input, amr side
        self.model_out_lines       = model_out_lines        # aligner output; may be filled in after alignment

    # Save the preprocess and model input data (optionally the original x_lines data)
    def save(self, wk_dir, save_input_data=False, **kwargs):
        self.build_filenames(wk_dir, **kwargs)
        if save_input_data:
            self.save_lines(self.eng_fn, self.eng_lines)
            self.save_lines(self.amr_fn, self.amr_lines)
        self.save_lines(self.eng_tok_pos_fn, self.eng_tok_origpos_lines)
        self.save_lines(self.amr_tuple_fn, self.amr_tuple_lines)
        # fast_align style parallel-corpus input: one "english ||| amr" pair per line
        with open(self.fa_in_fn, 'w') as f:
            for en_line, amr_line in zip(self.eng_preproc_lines, self.amr_preproc_lines):
                f.write('%s ||| %s\n' % (en_line, amr_line))

    # Load data from a working directory (not including the _preproc_lines)
    @classmethod
    def from_directory(cls, wk_dir, **kwargs):
        self = cls()
        self.build_filenames(wk_dir, **kwargs)
        self.eng_lines             = self.load_lines(self.eng_fn)
        self.amr_lines             = self.load_lines(self.amr_fn)
        self.eng_tok_origpos_lines = self.load_lines(self.eng_tok_pos_fn)
        self.amr_tuple_lines       = self.load_lines(self.amr_tuple_fn)
        self.model_out_lines       = self.load_lines(self.model_out_fn)
        return self

    # Create default filenames as members
    def build_filenames(self, wk_dir, **kwargs):
        self.eng_fn          = os.path.join(wk_dir, kwargs.get('eng_fn', 'sents.txt'))
        # BUG FIX: the kwargs key here was 'eng_fn' (copy-paste error), so passing
        # eng_fn=... would silently clobber the amr filename as well.
        self.amr_fn          = os.path.join(wk_dir, kwargs.get('amr_fn', 'gstrings.txt'))
        self.eng_tok_pos_fn  = os.path.join(wk_dir, kwargs.get('eng_tok_pos_fn', 'eng_tok_origpos.txt'))
        self.amr_tuple_fn    = os.path.join(wk_dir, kwargs.get('amr_tuple_fn', 'amr_tuple.txt'))
        self.fa_in_fn        = os.path.join(wk_dir, kwargs.get('fa_in_fn', 'fa_in.txt'))
        self.model_out_fn    = os.path.join(wk_dir, kwargs.get('model_out_fn', 'model_out.txt'))

    # Save a list of lines to a file
    @staticmethod
    def save_lines(fn, lines):
        with open(fn, 'w') as f:
            for line in lines:
                f.write(line + '\n')

    # Load a list of lines from a file
    @staticmethod
    def load_lines(fn):
        with open(fn) as f:
            lines = [l.strip() for l in f]
        return lines

scripts/61_FAA_Aligner/12_Preprocess_train.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
working_dir = 'amrlib/data/train_faa_aligner'
99
eng_fn = os.path.join(working_dir, 'sents.txt')
1010
amr_fn = os.path.join(working_dir, 'gstrings.txt')
11-
fa_in_fn = os.path.join(working_dir, 'fa_in.txt')
1211

1312
print('Reading and writing data in', working_dir)
1413
# Read in the english sentences and linearized AMR lines
@@ -18,9 +17,8 @@
1817
amr_lines = [l.strip().lower() for l in f]
1918

2019
# Preprocess the data
21-
eng_td_lines, amr_td_lines = preprocess_train(working_dir, eng_lines, amr_lines)
20+
data = preprocess_train(eng_lines, amr_lines)
2221

23-
# Save in fast align training format
24-
with open(fa_in_fn, 'w') as f:
25-
for en_line, amr_line in zip(eng_td_lines, amr_td_lines):
26-
f.write('%s ||| %s\n' % (en_line, amr_line))
22+
# Save the preprocess data and the model input file; the input data is
23+
# already in the working directory
24+
data.save(working_dir, save_input_data=False)

scripts/61_FAA_Aligner/16_PostProcess.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import setup_run_dir # this import tricks script to run from 2 levels up
33
import os
44
from amrlib.alignments.faa_aligner.postprocess import postprocess
5+
from amrlib.alignments.faa_aligner.proc_data import ProcData
56

67

78
if __name__ == '__main__':
@@ -10,7 +11,12 @@
1011
surface_fn = 'amr_surface_aligned.txt'
1112

1213
print('Reading and writing data in', working_dir)
13-
amr_surface_aligns, alignment_strings = postprocess(working_dir)
14+
15+
# Load the original, preprocess and model output data
16+
data = ProcData.from_directory(working_dir)
17+
18+
# Post process
19+
amr_surface_aligns, alignment_strings = postprocess(data)
1420

1521
# Save the final data
1622
fpath = os.path.join(working_dir, astrings_fn)

0 commit comments

Comments
 (0)