forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add terminology transform (OpenNMT#2442)
* Add terminology transform
- Loading branch information
Showing
4 changed files
with
370 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,348 @@ | ||
from onmt.utils.logging import logger | ||
from onmt.transforms import register_transform | ||
from .transform import Transform | ||
|
||
import spacy | ||
import ahocorasick | ||
import re | ||
|
||
|
||
class TermMatcher(object): | ||
def __init__( | ||
self, | ||
termbase_path, | ||
src_spacy_language_model, | ||
tgt_spacy_language_model, | ||
term_example_ratio, | ||
src_term_stoken, | ||
tgt_term_stoken, | ||
tgt_term_etoken, | ||
delimiter, | ||
term_corpus_ratio=0.2, | ||
): | ||
self.term_example_ratio = term_example_ratio | ||
self.src_nlp = spacy.load(src_spacy_language_model, disable=["parser", "ner"]) | ||
self.tgt_nlp = spacy.load(tgt_spacy_language_model, disable=["parser", "ner"]) | ||
|
||
# We exclude tokenization for contractions in | ||
# order to avoid inconsistencies with pyonmtok's tokenization. | ||
# (e.g. "I ca n't" with spacy, "I can ' t" with pyonmttok) | ||
self.src_nlp.tokenizer.rules = { | ||
key: value | ||
for key, value in self.src_nlp.tokenizer.rules.items() | ||
if "'" not in key and "’" not in key and "‘" not in key | ||
} | ||
self.tgt_nlp.tokenizer.rules = { | ||
key: value | ||
for key, value in self.tgt_nlp.tokenizer.rules.items() | ||
if "'" not in key and "’" not in key and "‘" not in key | ||
} | ||
self.internal_termbase = self._create_internal_termbase(termbase_path) | ||
self.automaton = self._create_automaton() | ||
self.term_corpus_ratio = term_corpus_ratio | ||
self.src_term_stoken = src_term_stoken | ||
self.tgt_term_stoken = tgt_term_stoken | ||
self.tgt_term_etoken = tgt_term_etoken | ||
self.delimiter = delimiter | ||
|
||
def _create_internal_termbase(self, termbase_path): | ||
logger.debug("Creating termbase with lemmas for Terminology transform") | ||
|
||
# Use Spacy's stopwords to get rid of junk entries | ||
src_stopwords = self.src_nlp.Defaults.stop_words | ||
tgt_stopwords = self.tgt_nlp.Defaults.stop_words | ||
termbase = list() | ||
with open(termbase_path, mode="r", encoding="utf-8") as file: | ||
pairs = file.readlines() | ||
for pair in pairs: | ||
src_term, tgt_term = map(str, pair.split("\t")) | ||
src_lemma = " ".join( | ||
"∥".join(tok.lemma_.split()) for tok in self.src_nlp(src_term) | ||
).strip() | ||
tgt_lemma = " ".join( | ||
tok.lemma_ for tok in self.tgt_nlp(tgt_term) | ||
).strip() | ||
if ( | ||
src_lemma.lower() not in src_stopwords | ||
and tgt_lemma.lower() not in tgt_stopwords | ||
): | ||
termbase.append((src_lemma, tgt_lemma)) | ||
logger.debug( | ||
f"Created termbase with {len(termbase)} lemmas " | ||
f"for Terminology transform" | ||
) | ||
return termbase | ||
|
||
def _create_automaton(self): | ||
automaton = ahocorasick.Automaton() | ||
for term in self.internal_termbase: | ||
automaton.add_word(term[0], (term[0], term[1])) | ||
automaton.make_automaton() | ||
return automaton | ||
|
||
def _src_sentence_with_terms(self, source_string, target_string) -> tuple: | ||
|
||
maybe_augmented = source_string.split(self.delimiter) | ||
source_only = maybe_augmented[0].strip() | ||
augmented_part = ( | ||
maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None | ||
) | ||
|
||
doc_src = self.src_nlp(source_only) | ||
doc_tgt = self.tgt_nlp(target_string) | ||
|
||
# Perform tokenization with spacy for consistency. | ||
tokenized_source = [tok.text for tok in doc_src] | ||
lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src] | ||
lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt] | ||
|
||
lemmatized_source_string = " ".join(lemmatized_source) | ||
|
||
offset = 0 | ||
source_with_terms = list() | ||
term_counter = 0 | ||
|
||
max_terms_allowed = int(len(tokenized_source) * self.term_example_ratio) | ||
is_match = False | ||
for match_end, (src_entry, tgt_entry) in self.automaton.iter_long( | ||
lemmatized_source_string | ||
): | ||
|
||
if term_counter == max_terms_allowed: | ||
break | ||
|
||
match_start = match_end - len(src_entry) + 1 | ||
|
||
# We ensure that the target lemma is present in the lemmatized | ||
# target string, that the match is an exact match (there is | ||
# whitespace before or after the term) | ||
# and we perform some bound checking. | ||
if ( | ||
(tgt_entry.lower() not in " ".join(lemmatized_target).lower()) | ||
or ( | ||
len(lemmatized_source_string) != match_end + 1 | ||
and not (lemmatized_source_string[match_end + 1].isspace()) | ||
) | ||
or ( | ||
not lemmatized_source_string[match_start - 1].isspace() | ||
and match_start != 0 | ||
) | ||
): | ||
continue | ||
else: | ||
term_counter += 1 | ||
|
||
# Map the lemmatized string match index to | ||
# the lemmatized list index | ||
lemma_list_index = 0 | ||
for i, w in enumerate(lemmatized_source): | ||
if lemma_list_index == match_start: | ||
lemma_list_index = i | ||
break | ||
else: | ||
lemma_list_index += len(w) + 1 | ||
|
||
# We need to know if the term is multiword | ||
num_words_in_src_term = len(src_entry.split()) | ||
src_term = " ".join( | ||
tokenized_source[ | ||
lemma_list_index : lemma_list_index + num_words_in_src_term | ||
] | ||
).strip() | ||
|
||
# Join multiword target lemmas with a unique separator so | ||
# we can treat them as single word and not change the indices. | ||
tgt_term = tgt_entry.replace(" ", "∥").rstrip().lower() | ||
source_with_terms.append( | ||
f"{lemmatized_source_string[offset: match_start]}" | ||
f"{self.src_term_stoken}∥{src_term}∥{self.tgt_term_stoken}∥" | ||
f"{tgt_term}∥{self.tgt_term_etoken}" | ||
) | ||
|
||
offset = match_end + 1 | ||
is_match = True | ||
|
||
if is_match: | ||
source_with_terms.append(lemmatized_source_string[offset:]) | ||
tokenized_source_with_terms = "".join(source_with_terms).split() | ||
|
||
if not ( | ||
len(tokenized_source) | ||
== len(lemmatized_source) | ||
== len(tokenized_source_with_terms) | ||
): | ||
final_string = " ".join(tokenized_source) | ||
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) | ||
return fixed_punct.split(), not is_match | ||
|
||
# Construct the final source from the lemmatized list | ||
# that contains the terms. We compare the tokens in the | ||
# term-augmented lemma list with the tokens in the original | ||
# lemma list. If the lemma is the same, then we replace with | ||
# the token from the original tokenized source list. If they | ||
# are not the same, it means the lemma has been augemented | ||
# with a term, so we inject this in the final list. | ||
completed_tokenized_source = list() | ||
for idx in range(len(tokenized_source_with_terms)): | ||
# Restore the spaces in multi-word terms | ||
src_lemma = tokenized_source_with_terms[idx].replace("∥", " ") | ||
if lemmatized_source[idx].replace("∥", " ") == src_lemma: | ||
completed_tokenized_source.append(tokenized_source[idx]) | ||
else: | ||
completed_tokenized_source.append(src_lemma) | ||
|
||
if augmented_part is not None: | ||
final_string = " ".join( | ||
completed_tokenized_source | ||
+ [self.delimiter] | ||
+ augmented_part.split() | ||
) | ||
else: | ||
final_string = " ".join(completed_tokenized_source) | ||
|
||
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) | ||
return fixed_punct.split(), is_match | ||
else: | ||
final_string = " ".join(tokenized_source) | ||
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string) | ||
return fixed_punct.split(), not is_match | ||
|
||
|
||
@register_transform(name="terminology") | ||
class TerminologyTransform(Transform): | ||
def __init__(self, opts): | ||
super().__init__(opts) | ||
|
||
@classmethod | ||
def add_options(cls, parser): | ||
"""Available options for terminology matching.""" | ||
|
||
group = parser.add_argument_group("Transform/Terminology") | ||
group.add( | ||
"--termbase_path", | ||
"-termbase_path", | ||
type=str, | ||
help="Path to a dictionary file with terms.", | ||
) | ||
group.add( | ||
"--src_spacy_language_model", | ||
"-src_spacy_language_model", | ||
type=str, | ||
help="Name of the spacy language model for the source corpus.", | ||
) | ||
group.add( | ||
"--tgt_spacy_language_model", | ||
"-tgt_spacy_language_model", | ||
type=str, | ||
help="Name of the spacy language model for the target corpus.", | ||
) | ||
group.add( | ||
"--term_corpus_ratio", | ||
"-term_corpus_ratio", | ||
type=float, | ||
default=0.3, | ||
help="Ratio of corpus to augment with terms.", | ||
) | ||
group.add( | ||
"--term_example_ratio", | ||
"-term_example_ratio", | ||
type=float, | ||
default=0.2, | ||
help="Max terms allowed in an example.", | ||
) | ||
group.add( | ||
"--src_term_stoken", | ||
"-src_term_stoken", | ||
type=str, | ||
help="The source term start token.", | ||
default="⦅src_term_start⦆", | ||
) | ||
group.add( | ||
"--tgt_term_stoken", | ||
"-tgt_term_stoken", | ||
type=str, | ||
help="The target term start token.", | ||
default="⦅tgt_term_start⦆", | ||
) | ||
group.add( | ||
"--tgt_term_etoken", | ||
"-tgt_term_etoken", | ||
type=str, | ||
help="The target term end token.", | ||
default="⦅tgt_term_end⦆", | ||
) | ||
group.add( | ||
"--term_source_delimiter", | ||
"-term_source_delimiter", | ||
type=str, | ||
help="Any special token used for augmented source sentences. " | ||
"The default is the fuzzy token used in the " | ||
"FuzzyMatch transform.", | ||
default="⦅fuzzy⦆", | ||
) | ||
|
||
def _parse_opts(self): | ||
self.termbase_path = self.opts.termbase_path | ||
self.src_spacy_language_model = self.opts.src_spacy_language_model | ||
self.tgt_spacy_language_model = self.opts.tgt_spacy_language_model | ||
self.term_corpus_ratio = self.opts.term_corpus_ratio | ||
self.term_example_ratio = self.opts.term_example_ratio | ||
self.term_source_delimiter = self.opts.term_source_delimiter | ||
self.src_term_stoken = self.opts.src_term_stoken | ||
self.tgt_term_stoken = self.opts.tgt_term_stoken | ||
self.tgt_term_etoken = self.opts.tgt_term_etoken | ||
|
||
@classmethod | ||
def get_specials(cls, opts): | ||
"""Add the term tokens to the src vocab.""" | ||
src_specials = list() | ||
src_specials.extend( | ||
[opts.src_term_stoken, opts.tgt_term_stoken, opts.tgt_term_etoken] | ||
) | ||
return (src_specials, list()) | ||
|
||
def warm_up(self, vocabs=None): | ||
"""Create the terminology matcher.""" | ||
|
||
super().warm_up(None) | ||
self.termmatcher = TermMatcher( | ||
self.termbase_path, | ||
self.src_spacy_language_model, | ||
self.tgt_spacy_language_model, | ||
self.term_example_ratio, | ||
self.src_term_stoken, | ||
self.tgt_term_stoken, | ||
self.tgt_term_etoken, | ||
self.term_source_delimiter, | ||
self.term_corpus_ratio, | ||
) | ||
|
||
def batch_apply(self, batch, is_train=False, stats=None, **kwargs): | ||
bucket_size = len(batch) | ||
examples_with_terms = 0 | ||
|
||
for i, (ex, _, _) in enumerate(batch): | ||
# Skip half examples to improve performance. This means we set | ||
# a hard limit for the `term_corpus_ratio` to 0.5, which is actually | ||
# quite high. TODO: We can add this (skipping examples) as an option | ||
if i % 2 == 0: | ||
original_src = ex["src"] | ||
augmented_example, is_match = self.apply(ex, is_train, stats, **kwargs) | ||
if is_match and ( | ||
examples_with_terms < bucket_size * self.term_corpus_ratio | ||
): | ||
examples_with_terms += 1 | ||
ex["src"] = augmented_example["src"] | ||
else: | ||
ex["src"] = original_src | ||
|
||
logger.debug(f"Added terms to {examples_with_terms}/{bucket_size} examples") | ||
return batch | ||
|
||
def apply(self, example, is_train=False, stats=None, **kwargs) -> tuple: | ||
"""Add terms to source examples.""" | ||
|
||
example["src"], is_match = self.termmatcher._src_sentence_with_terms( | ||
" ".join(example["src"]), " ".join(example["tgt"]) | ||
) | ||
return example, is_match |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ rapidfuzz | |
scipy | ||
bitsandbytes>=0.39.0 | ||
safetensors | ||
spacy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
"rapidfuzz", | ||
"pyahocorasick", | ||
"fasttext-wheel", | ||
"spacy", | ||
], | ||
entry_points={ | ||
"console_scripts": [ | ||
|