diff --git a/README.md b/README.md
index e07ac2a64..9aec1dfb0 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@ CTranslate2 is a fast and full-featured inference engine for Transformer models.
 
 * [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)
 * [OpenNMT-tf](https://github.com/OpenNMT/OpenNMT-tf)
 * [Fairseq](https://github.com/pytorch/fairseq/)
+* [Marian](https://github.com/marian-nmt/marian)
 
 The project is production-oriented and comes with [backward compatibility guarantees](#what-is-the-state-of-this-project), but it also includes experimental features related to model compression and inference acceleration.
@@ -66,7 +67,7 @@ pip install --upgrade pip
 pip install ctranslate2
 ```
 
-**2\. [Convert](#converting-models) a Transformer model trained with OpenNMT-py, OpenNMT-tf, or Fairseq:**
+**2\. [Convert](#converting-models) a Transformer model trained with OpenNMT-py, OpenNMT-tf, Fairseq, or Marian:**
 
 *a. OpenNMT-py*
 
@@ -106,6 +107,20 @@ ct2-fairseq-converter --model_path wmt16.en-de.joined-dict.transformer/model.pt
   --output_dir ende_ctranslate2
 ```
 
+*d. Marian*
+
+```bash
+wget https://object.pouta.csc.fi/OPUS-MT-models/en-de/opus-2020-02-26.zip
+unzip opus-2020-02-26.zip
+
+ct2-marian-converter --model_path opus.spm32k-spm32k.transformer-align.model1.npz.best-perplexity.npz \
+    --vocab_paths opus.spm32k-spm32k.vocab.yml opus.spm32k-spm32k.vocab.yml \
+    --output_dir ende_ctranslate2
+
+# For OPUS-MT models, you can use ct2-opus-mt-converter instead:
+ct2-opus-mt-converter --model_dir . --output_dir ende_ctranslate2
+```
+
 **3\. [Translate](#translating) tokenized inputs with the Python API:**
 
 ```python
@@ -113,11 +128,11 @@ import ctranslate2
 
 translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")
 
-# The OpenNMT-py and OpenNMT-tf models use a SentencePiece tokenization:
-translator.translate_batch([["▁H", "ello", "▁world", "!"]])
+batch = [["▁H", "ello", "▁world", "!"]] # OpenNMT model input
+# batch = [["H@@", "ello", "world@@", "!"]] # Fairseq model input
+# batch = [["▁Hello", "▁world", "!"]] # Marian model input
 
-# The Fairseq model uses a BPE tokenization:
-translator.translate_batch([["H@@", "ello", "world@@", "!"]])
+translator.translate_batch(batch)
 ```
 
 ## Installation
@@ -163,10 +178,10 @@ The core CTranslate2 implementation is framework agnostic. The framework specifi
 
 The following frameworks and models are currently supported:
 
-| | OpenNMT-tf | OpenNMT-py | Fairseq |
-| --- | :---: | :---: | :---: |
-| Transformer ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) | ✓ | ✓ | ✓ |
-| + relative position representations ([Shaw et al. 2018](https://arxiv.org/abs/1803.02155)) | ✓ | ✓ | |
+| | OpenNMT-tf | OpenNMT-py | Fairseq | Marian |
+| --- | :---: | :---: | :---: | :---: |
+| Transformer ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) | ✓ | ✓ | ✓ | ✓ |
+| + relative position representations ([Shaw et al. 2018](https://arxiv.org/abs/1803.02155)) | ✓ | ✓ | | |
 
 *If you are using a model that is not listed above, consider opening an issue to discuss future integration.*
 
@@ -175,6 +190,8 @@ The Python package includes a [conversion API](docs/python.md#model-conversion-a
 
 * `ct2-opennmt-py-converter`
 * `ct2-opennmt-tf-converter`
 * `ct2-fairseq-converter`
+* `ct2-marian-converter`
+* `ct2-opus-mt-converter` (based on `ct2-marian-converter`)
 
 The conversion should be run in the same environment as the selected training framework.
@@ -480,8 +497,10 @@ The implementation has been generously tested in [production environment](https:
 * Python symbols:
   * `ctranslate2.Translator`
   * `ctranslate2.converters.FairseqConverter`
+  * `ctranslate2.converters.MarianConverter`
   * `ctranslate2.converters.OpenNMTPyConverter`
   * `ctranslate2.converters.OpenNMTTFConverter`
+  * `ctranslate2.converters.OpusMTConverter`
 * C++ symbols:
   * `ctranslate2::models::Model`
   * `ctranslate2::TranslationOptions`
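For reference, the `ct2-marian-converter` command in the README quickstart above has a direct Python equivalent through the conversion API added by this change. A minimal sketch, assuming the OPUS-MT archive was extracted into the working directory (the same vocabulary file is passed twice because this model shares a single SentencePiece vocabulary between source and target):

```python
import ctranslate2

converter = ctranslate2.converters.MarianConverter(
    model_path="opus.spm32k-spm32k.transformer-align.model1.npz.best-perplexity.npz",
    vocab_paths=[
        "opus.spm32k-spm32k.vocab.yml",  # source vocabulary
        "opus.spm32k-spm32k.vocab.yml",  # target vocabulary (shared here)
    ],
)
converter.convert("ende_ctranslate2")
```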
diff --git a/docs/python.md b/docs/python.md
index 01cf04545..3cd472793 100644
--- a/docs/python.md
+++ b/docs/python.md
@@ -27,6 +27,15 @@ converter = ctranslate2.converters.FairseqConverter(
     fixed_dictionary: str = None, # Path to the fixed dictionary for multilingual models.
 )
 
+converter = ctranslate2.converters.MarianConverter(
+    model_path: str, # Path to the Marian model (.npz file).
+    vocab_paths: List[str], # Paths to the vocabularies (.yml files).
+)
+
+converter = ctranslate2.converters.OpusMTConverter(
+    model_dir: str, # Path to the OPUS-MT model directory.
+)
+
 output_dir = converter.convert(
     output_dir: str, # Path to the output directory.
     vmap: str = None, # Path to a vocabulary mapping file.
diff --git a/include/ctranslate2/layers/decoder.h b/include/ctranslate2/layers/decoder.h
index 1f7d81bc0..6d0e91cdf 100644
--- a/include/ctranslate2/layers/decoder.h
+++ b/include/ctranslate2/layers/decoder.h
@@ -11,6 +11,8 @@ namespace ctranslate2 {
 
     using DecoderState = std::unordered_map<std::string, StorageView>;
 
+    void zero_first_timestep(StorageView& x, dim_t step);
+
     // Base class for decoders.
     class Decoder : public Layer {
     public:
diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h
index 26b218e58..4a98508a3 100644
--- a/include/ctranslate2/layers/transformer.h
+++ b/include/ctranslate2/layers/transformer.h
@@ -191,6 +191,7 @@ namespace ctranslate2 {
       dim_t _alignment_heads;
       const ComputeType _compute_type;
      const Embeddings _embeddings;
+      const bool _start_from_zero_embedding;
       const std::unique_ptr<const StorageView> _embeddings_scale;
       const std::unique_ptr<PositionEncoder> _position_encoder;
       const std::unique_ptr<LayerNorm> _layernorm_embedding;
diff --git a/python/ctranslate2/converters/__init__.py b/python/ctranslate2/converters/__init__.py
index 6f86a7afc..668c1b0ff 100644
--- a/python/ctranslate2/converters/__init__.py
+++ b/python/ctranslate2/converters/__init__.py
@@ -1,4 +1,6 @@
 from ctranslate2.converters.converter import Converter
 from ctranslate2.converters.fairseq import FairseqConverter
+from ctranslate2.converters.marian import MarianConverter
 from ctranslate2.converters.opennmt_py import OpenNMTPyConverter
 from ctranslate2.converters.opennmt_tf import OpenNMTTFConverter
+from ctranslate2.converters.opus_mt import OpusMTConverter
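The `OpusMTConverter` documented in `docs/python.md` above wraps `MarianConverter` and reads the model and vocabulary paths from the directory's `decoder.yml`. A minimal sketch, where `opus-mt-ende` is a hypothetical directory holding an extracted OPUS-MT release:

```python
import ctranslate2

# The directory is expected to contain decoder.yml, the .npz model,
# and the .yml vocabularies referenced by decoder.yml.
converter = ctranslate2.converters.OpusMTConverter("opus-mt-ende")
converter.convert("ende_ctranslate2")
```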
diff --git a/python/ctranslate2/converters/fairseq.py b/python/ctranslate2/converters/fairseq.py
index 474e8d7ec..f760bfb30 100644
--- a/python/ctranslate2/converters/fairseq.py
+++ b/python/ctranslate2/converters/fairseq.py
@@ -32,34 +32,36 @@ def _get_model_spec(args):
     activation_fn = getattr(args, "activation_fn", "relu")
 
-    reasons = []
-    if args.arch not in _SUPPORTED_ARCHS:
-        reasons.append(
-            "Option --arch %s is not supported (supported architectures are: %s)"
-            % (args.arch, ", ".join(_SUPPORTED_ARCHS))
-        )
-    if args.encoder_normalize_before != args.decoder_normalize_before:
-        reasons.append(
-            "Options --encoder-normalize-before and --decoder-normalize-before "
-            "must have the same value"
-        )
-    if args.encoder_attention_heads != args.decoder_attention_heads:
-        reasons.append(
-            "Options --encoder-attention-heads and --decoder-attention-heads must "
-            "have the same value"
-        )
-    if activation_fn not in _SUPPORTED_ACTIVATIONS.keys():
-        reasons.append(
-            "Option --activation-fn %s is not supported (supported activations are: %s)"
-            % (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys()))
-        )
-    if getattr(args, "no_token_positional_embeddings", False):
-        reasons.append("Option --no-token-positional-embeddings is not supported")
-    if getattr(args, "lang_tok_replacing_bos_eos", False):
-        reasons.append("Option --lang-tok-replacing-bos-eos is not supported")
-
-    if reasons:
-        utils.raise_unsupported(reasons)
+    check = utils.ConfigurationChecker()
+    check(
+        args.arch in _SUPPORTED_ARCHS,
+        "Option --arch %s is not supported (supported architectures are: %s)"
+        % (args.arch, ", ".join(_SUPPORTED_ARCHS)),
+    )
+    check(
+        args.encoder_normalize_before == args.decoder_normalize_before,
+        "Options --encoder-normalize-before and --decoder-normalize-before "
+        "must have the same value",
+    )
+    check(
+        args.encoder_attention_heads == args.decoder_attention_heads,
+        "Options --encoder-attention-heads and --decoder-attention-heads "
+        "must have the same value",
+    )
+    check(
+        activation_fn in _SUPPORTED_ACTIVATIONS,
+        "Option --activation-fn %s is not supported (supported activations are: %s)"
+        % (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
+    )
+    check(
+        not getattr(args, "no_token_positional_embeddings", False),
+        "Option --no-token-positional-embeddings is not supported",
+    )
+    check(
+        not getattr(args, "lang_tok_replacing_bos_eos", False),
+        "Option --lang-tok-replacing-bos-eos is not supported",
+    )
+    check.validate()
 
     return transformer_spec.TransformerSpec(
         (args.encoder_layers, args.decoder_layers),
diff --git a/python/ctranslate2/converters/marian.py b/python/ctranslate2/converters/marian.py
new file mode 100644
index 000000000..fef61fe1a
--- /dev/null
+++ b/python/ctranslate2/converters/marian.py
@@ -0,0 +1,267 @@
+import argparse
+import re
+
+import numpy as np
+import yaml
+
+from ctranslate2.converters import utils
+from ctranslate2.converters.converter import Converter
+from ctranslate2.specs import common_spec, model_spec, transformer_spec
+
+_SUPPORTED_ACTIVATIONS = {
+    "gelu": common_spec.Activation.GELU,
+    "relu": common_spec.Activation.RELU,
+    "swish": common_spec.Activation.SWISH,
+}
+
+_SUPPORTED_POSTPROCESS_EMB = {"", "d", "n", "nd"}
+
+
+class MarianConverter(Converter):
+    """Converts models trained with Marian."""
+
+    def __init__(self, model_path, vocab_paths):
+        self._model_path = model_path
+        self._vocab_paths = vocab_paths
+
+    def _load(self):
+        model = np.load(self._model_path)
+        config = _get_model_config(model)
+        vocabs = list(map(_load_vocab, self._vocab_paths))
+
+        activation = config["transformer-ffn-activation"]
+        pre_norm = "n" in config["transformer-preprocess"]
+        postprocess_emb = config["transformer-postprocess-emb"]
+
+        check = utils.ConfigurationChecker()
+        check(config["type"] == "transformer", "Option --type must be 'transformer'")
+        check(
+            config["transformer-decoder-autoreg"] == "self-attention",
+            "Option --transformer-decoder-autoreg must be 'self-attention'",
+        )
+        check(
+            not config["transformer-no-projection"],
+            "Option --transformer-no-projection is not supported",
+        )
+        check(
+            activation in _SUPPORTED_ACTIVATIONS,
+            "Option --transformer-ffn-activation %s is not supported "
+            "(supported activations are: %s)"
+            % (activation, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
+        )
+        check(
+            postprocess_emb in _SUPPORTED_POSTPROCESS_EMB,
"Option --transformer-postprocess-emb %s is not supported (supported values are: %s)" + % (postprocess_emb, ", ".join(_SUPPORTED_POSTPROCESS_EMB)), + ) + + if pre_norm: + check( + config["transformer-preprocess"] == "n" + and config["transformer-postprocess"] == "da" + and config.get("transformer-postprocess-top", "") == "n", + "Unsupported pre-norm Transformer architecture, expected the following " + "combination of options: " + "--transformer-preprocess n " + "--transformer-postprocess da " + "--transformer-postprocess-top n", + ) + else: + check( + config["transformer-preprocess"] == "" + and config["transformer-postprocess"] == "dan" + and config.get("transformer-postprocess-top", "") == "", + "Unsupported post-norm Transformer architecture, excepted the following " + "combination of options: " + "--transformer-preprocess '' " + "--transformer-postprocess dan " + "--transformer-postprocess-top ''", + ) + + check.validate() + + alignment_layer = config["transformer-guided-alignment-layer"] + alignment_layer = -1 if alignment_layer == "last" else int(alignment_layer) - 1 + layernorm_embedding = "n" in postprocess_emb + + model_spec = transformer_spec.TransformerSpec( + (config["enc-depth"], config["dec-depth"]), + config["transformer-heads"], + pre_norm=pre_norm, + activation=_SUPPORTED_ACTIVATIONS[activation], + alignment_layer=alignment_layer, + alignment_heads=1, + layernorm_embedding=layernorm_embedding, + ) + set_transformer_spec(model_spec, model) + model_spec.register_source_vocabulary(vocabs[0]) + model_spec.register_target_vocabulary(vocabs[-1]) + model_spec.with_source_eos = True + return model_spec + + +def _get_model_config(model): + config = model["special:model.yml"] + config = config[:-1].tobytes() + config = yaml.safe_load(config) + return config + + +def _load_vocab(path): + # pyyaml skips some entries so we manually parse the vocabulary file. + with open(path, encoding="utf-8") as vocab: + tokens = [] + for i, line in enumerate(vocab): + line = line.rstrip("\n\r") + token, idx = line.rsplit(":", 1) + try: + int(idx.strip()) + except ValueError as e: + raise ValueError( + "Unexpected format at line %d: '%s'" % (i + 1, line) + ) from e + if token.startswith('"') and token.endswith('"'): + # Unescape characters and remove quotes. 
+                token = re.sub(r"\\(.)", r"\1", token)
+                token = token[1:-1]
+            tokens.append(token)
+    return tokens
+
+
+def set_transformer_spec(spec, weights):
+    set_transformer_encoder(spec.encoder, weights, "encoder")
+    set_transformer_decoder(spec.decoder, weights, "decoder")
+
+
+def set_transformer_encoder(spec, weights, scope):
+    set_common_layers(spec, weights, scope)
+    for i, layer_spec in enumerate(spec.layer):
+        set_transformer_encoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
+
+
+def set_transformer_decoder(spec, weights, scope):
+    spec.start_from_zero_embedding = True
+    set_common_layers(spec, weights, scope)
+    for i, layer_spec in enumerate(spec.layer):
+        set_transformer_decoder_layer(layer_spec, weights, "%s_l%d" % (scope, i + 1))
+
+    set_linear(
+        spec.projection,
+        weights,
+        "%s_ff_logit_out" % scope,
+        reuse_weight=spec.embeddings.weight,
+    )
+
+
+def set_common_layers(spec, weights, scope):
+    embeddings_specs = spec.embeddings
+    if not isinstance(embeddings_specs, list):
+        embeddings_specs = [embeddings_specs]
+
+    set_embeddings(embeddings_specs[0], weights, scope)
+    set_position_encodings(
+        spec.position_encodings, weights, dim=embeddings_specs[0].weight.shape[1]
+    )
+    if spec.layernorm_embedding != model_spec.OPTIONAL:
+        set_layer_norm(spec.layernorm_embedding, weights, "%s_emb" % scope)
+    if spec.layer_norm != model_spec.OPTIONAL:
+        set_layer_norm(spec.layer_norm, weights, "%s_top" % scope)
+
+
+def set_transformer_encoder_layer(spec, weights, scope):
+    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
+    set_multi_head_attention(
+        spec.self_attention, weights, "%s_self" % scope, self_attention=True
+    )
+
+
+def set_transformer_decoder_layer(spec, weights, scope):
+    set_ffn(spec.ffn, weights, "%s_ffn" % scope)
+    set_multi_head_attention(
+        spec.self_attention, weights, "%s_self" % scope, self_attention=True
+    )
+    set_multi_head_attention(spec.attention, weights, "%s_context" % scope)
+
+
+def set_multi_head_attention(spec, weights, scope, self_attention=False):
+    split_layers = [common_spec.LinearSpec() for _ in range(3)]
+    set_linear(split_layers[0], weights, scope, "q")
+    set_linear(split_layers[1], weights, scope, "k")
+    set_linear(split_layers[2], weights, scope, "v")
+
+    if self_attention:
+        utils.fuse_linear(spec.linear[0], split_layers)
+    else:
+        spec.linear[0].weight = split_layers[0].weight
+        spec.linear[0].bias = split_layers[0].bias
+        utils.fuse_linear(spec.linear[1], split_layers[1:])
+
+    set_linear(spec.linear[-1], weights, scope, "o")
+    set_layer_norm(spec.layer_norm, weights, "%s_Wo" % scope)
+
+
+def set_ffn(spec, weights, scope):
+    set_layer_norm(spec.layer_norm, weights, "%s_ffn" % scope)
+    set_linear(spec.linear_0, weights, scope, "1")
+    set_linear(spec.linear_1, weights, scope, "2")
+
+
+def set_layer_norm(spec, weights, scope):
+    spec.gamma = weights["%s_ln_scale" % scope].squeeze()
+    spec.beta = weights["%s_ln_bias" % scope].squeeze()
+
+
+def set_linear(spec, weights, scope, suffix="", reuse_weight=None):
+    weight = weights.get("%s_W%s" % (scope, suffix))
+    spec.weight = reuse_weight if weight is None else weight.transpose()
+
+    bias = weights.get("%s_b%s" % (scope, suffix))
+    if bias is not None:
+        spec.bias = bias.squeeze()
+
+
+def set_embeddings(spec, weights, scope):
+    spec.weight = weights.get("%s_Wemb" % scope)
+    if spec.weight is None:
+        spec.weight = weights.get("Wemb")
+
+
+def set_position_encodings(spec, weights, dim=None):
+    spec.encodings = weights.get("Wpos", _make_sinusoidal_position_encodings(dim))
+
+
+def _make_sinusoidal_position_encodings(dim, num_positions=2048):
+    # Copied from https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/models/marian/modeling_marian.py  # noqa: E501
+    position_enc = np.array(
+        [
+            [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
+            for pos in range(num_positions)
+        ]
+    )
+    table = np.zeros_like(position_enc)
+    table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
+    table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
+    return table
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "--model_path", required=True, help="Path to the model .npz file."
+    )
+    parser.add_argument(
+        "--vocab_paths",
+        required=True,
+        nargs="+",
+        help="List of paths to the YAML vocabularies.",
+    )
+    Converter.declare_arguments(parser)
+    args = parser.parse_args()
+    converter = MarianConverter(args.model_path, args.vocab_paths)
+    converter.convert_from_args(args)
+
+
+if __name__ == "__main__":
+    main()
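`_make_sinusoidal_position_encodings` above reproduces Marian's layout, which concatenates all sine components before all cosine components instead of interleaving them. A quick sanity check (an illustrative snippet, not part of the diff; it imports a private helper, which is acceptable for a one-off verification):

```python
import numpy as np

from ctranslate2.converters.marian import _make_sinusoidal_position_encodings

table = _make_sinusoidal_position_encodings(512)
assert table.shape == (2048, 512)  # (num_positions, dim)

# At position 0 every angle is 0: the sine half is 0 and the cosine half is 1.
np.testing.assert_allclose(table[0, :256], 0.0)
np.testing.assert_allclose(table[0, 256:], 1.0)
```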
getattr(opt, "heads", 8) return transformer_spec.TransformerSpec( diff --git a/python/ctranslate2/converters/opus_mt.py b/python/ctranslate2/converters/opus_mt.py new file mode 100644 index 000000000..0f2fdab49 --- /dev/null +++ b/python/ctranslate2/converters/opus_mt.py @@ -0,0 +1,39 @@ +import argparse +import os + +import yaml + +from ctranslate2.converters.marian import MarianConverter + + +class OpusMTConverter(MarianConverter): + """Converts models trained with OPUS-MT.""" + + def __init__(self, model_dir): + with open( + os.path.join(model_dir, "decoder.yml"), encoding="utf-8" + ) as decoder_file: + decoder_config = yaml.safe_load(decoder_file) + + model_path = os.path.join(model_dir, decoder_config["models"][0]) + vocab_paths = [ + os.path.join(model_dir, path) for path in decoder_config["vocabs"] + ] + super().__init__(model_path, vocab_paths) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--model_dir", required=True, help="Path to the OPUS-MT model directory." + ) + OpusMTConverter.declare_arguments(parser) + args = parser.parse_args() + converter = OpusMTConverter(args.model_dir) + converter.convert_from_args(args) + + +if __name__ == "__main__": + main() diff --git a/python/ctranslate2/converters/utils.py b/python/ctranslate2/converters/utils.py index 3cd0e2598..fb58545a6 100644 --- a/python/ctranslate2/converters/utils.py +++ b/python/ctranslate2/converters/utils.py @@ -14,3 +14,16 @@ def raise_unsupported(reasons): for reason in reasons: message += "\n- " + reason raise ValueError(message) + + +class ConfigurationChecker: + def __init__(self): + self._unsupported_reasons = [] + + def __call__(self, assert_condition, error_message): + if not assert_condition: + self._unsupported_reasons.append(error_message) + + def validate(self): + if self._unsupported_reasons: + raise_unsupported(self._unsupported_reasons) diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 4ba4cda44..6fa8c5206 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -115,6 +115,7 @@ def __init__(self, num_layers, pre_norm=True, layernorm_embedding=False): ) self.projection = common_spec.LinearSpec() self.layer = [TransformerDecoderLayerSpec() for _ in range(num_layers)] + self.start_from_zero_embedding = False class TransformerEncoderLayerSpec(model_spec.LayerSpec): diff --git a/python/setup.py b/python/setup.py index c033dcb63..b77811d61 100644 --- a/python/setup.py +++ b/python/setup.py @@ -99,12 +99,15 @@ def _maybe_add_library_root(lib_name): python_requires=">=3.6,<3.11", install_requires=[ "numpy", + "pyyaml>=5.3,<7", ], entry_points={ "console_scripts": [ "ct2-fairseq-converter=ctranslate2.converters.fairseq:main", + "ct2-marian-converter=ctranslate2.converters.marian:main", "ct2-opennmt-py-converter=ctranslate2.converters.opennmt_py:main", "ct2-opennmt-tf-converter=ctranslate2.converters.opennmt_tf:main", + "ct2-opus-mt-converter=ctranslate2.converters.opus_mt:main", ], }, ) diff --git a/python/tests/test.py b/python/tests/test.py index 754a07757..6cb816a7f 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -814,6 +814,17 @@ def _load(self): assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"] +@skip_if_data_missing +def test_marian_model_conversion(tmpdir): + model_dir = os.path.join(_TEST_DATA_DIR, "models", "opus-mt-ende") + converter = 
diff --git a/python/tests/test.py b/python/tests/test.py
index 754a07757..6cb816a7f 100644
--- a/python/tests/test.py
+++ b/python/tests/test.py
@@ -814,6 +814,17 @@ def _load(self):
     assert output[0].hypotheses[0] == ["a", "t", "z", "m", "o", "n"]
 
 
+@skip_if_data_missing
+def test_marian_model_conversion(tmpdir):
+    model_dir = os.path.join(_TEST_DATA_DIR, "models", "opus-mt-ende")
+    converter = ctranslate2.converters.OpusMTConverter(model_dir)
+    output_dir = str(tmpdir.join("ctranslate2_model"))
+    converter.convert(output_dir)
+    translator = ctranslate2.Translator(output_dir)
+    output = translator.translate_batch([["▁Hello", "▁world", "!"]])
+    assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"]
+
+
 def test_layer_spec_validate():
     class SubSpec(ctranslate2.specs.LayerSpec):
         def __init__(self):
diff --git a/python/tools/prepare_test_environment.sh b/python/tools/prepare_test_environment.sh
index f923e5bd8..18e322269 100755
--- a/python/tools/prepare_test_environment.sh
+++ b/python/tools/prepare_test_environment.sh
@@ -10,3 +10,6 @@ pip uninstall -y ctranslate2
 # Download test data
 curl -o transliteration-aren-all.tar.gz https://opennmt-models.s3.amazonaws.com/transliteration-aren-all.tar.gz
 tar xf transliteration-aren-all.tar.gz -C tests/data/models/
+
+curl -O https://object.pouta.csc.fi/OPUS-MT-models/en-de/opus-2020-02-26.zip
+unzip opus-2020-02-26.zip -d tests/data/models/opus-mt-ende
diff --git a/src/layers/decoder.cc b/src/layers/decoder.cc
index 74911c1f2..10a62e56d 100644
--- a/src/layers/decoder.cc
+++ b/src/layers/decoder.cc
@@ -5,6 +5,22 @@
 namespace ctranslate2 {
   namespace layers {
 
+    void zero_first_timestep(StorageView& x, dim_t step) {
+      if (step == 0) {
+        x.zero();
+      } else if (step < 0) {
+        // TODO: a more direct way to set the first timestep to 0.
+        const auto dtype = x.dtype();
+        const auto device = x.device();
+        StorageView first_step(dtype, device);
+        StorageView other_steps(dtype, device);
+        ops::Split(1, {1, x.dim(1) - 1})(x, first_step, other_steps);
+        first_step.zero();
+        ops::Concat(1)({&first_step, &other_steps}, x);
+      }
+    }
+
+
     Decoder::Decoder(Device device)
       : _device(device) {
     }
diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc
index c0d07a4d7..bd5653138 100644
--- a/src/layers/transformer.cc
+++ b/src/layers/transformer.cc
@@ -243,6 +243,8 @@ namespace ctranslate2 {
       , _num_heads(num_heads)
       , _compute_type(model.effective_compute_type())
       , _embeddings(model, scope + "/embeddings")
+      , _start_from_zero_embedding(model.get_flag_with_default(scope + "/start_from_zero_embedding",
+                                                               false))
       , _embeddings_scale(build_embeddings_scale(model, scope, _embeddings))
       , _position_encoder(with_position_encoding
                           ? build_position_encoder(model, scope + "/position_encodings", _embeddings)
@@ -332,7 +334,9 @@ namespace ctranslate2 {
       StorageView layer_in(output_type(), ids.device());
       _embeddings(ids, layer_in);
-      if (_embeddings_scale)
+      if (_start_from_zero_embedding)
+        zero_first_timestep(layer_in, step);
+      if (_embeddings_scale && (!_start_from_zero_embedding || step != 0))
         ops::Mul()(layer_in, *_embeddings_scale, layer_in);
       if (layer_in.rank() == 2)
         layer_in.expand_dims(1);
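The shape manipulation in `zero_first_timestep` is easier to follow with the dimensions spelled out. A NumPy sketch of the same logic (illustrative only; `x` stands for the decoder input embeddings of shape `[batch, time, depth]`):

```python
import numpy as np

def zero_first_timestep(x: np.ndarray, step: int) -> np.ndarray:
    # Marian feeds a zero vector instead of a <bos> embedding at the first
    # decoded position, so the converted model reproduces that here.
    if step == 0:
        # Step-by-step decoding: the current input *is* the first timestep.
        x[:] = 0
    elif step < 0:
        # Full-sequence mode (e.g. scoring): zero only timestep 0. This is
        # what the C++ version does with its split/zero/concat sequence.
        x[:, :1, :] = 0
    return x

emb = np.random.rand(2, 4, 8).astype(np.float32)
out = zero_first_timestep(emb, step=-1)
assert (out[:, 0] == 0).all() and (out[:, 1:] != 0).any()
```

This also explains the guard on the multiplication in `transformer.cc`: at `step == 0` the embedding is entirely zero, so applying the embedding scale can be skipped.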