Add converters for Marian and OPUS-MT (#720)
* Add converters for Marian and OPUS-MT

* Add missing backslash

* Decode bytes

* Fix encoding

* Fix encoding
guillaumekln authored Feb 28, 2022
1 parent 25d9e35 commit dcb7888
Showing 16 changed files with 456 additions and 66 deletions.
37 changes: 28 additions & 9 deletions README.md
@@ -7,6 +7,7 @@ CTranslate2 is a fast and full-featured inference engine for Transformer models.
* [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py)
* [OpenNMT-tf](https://github.com/OpenNMT/OpenNMT-tf)
* [Fairseq](https://github.com/pytorch/fairseq/)
* [Marian](https://github.com/marian-nmt/marian)

The project is production-oriented and comes with [backward compatibility guarantees](#what-is-the-state-of-this-project), but it also includes experimental features related to model compression and inference acceleration.

@@ -66,7 +67,7 @@ pip install --upgrade pip
pip install ctranslate2
```

**2\. [Convert](#converting-models) a Transformer model trained with OpenNMT-py, OpenNMT-tf, or Fairseq:**
**2\. [Convert](#converting-models) a Transformer model trained with OpenNMT-py, OpenNMT-tf, Fairseq, or Marian:**

*a. OpenNMT-py*

@@ -106,18 +107,32 @@ ct2-fairseq-converter --model_path wmt16.en-de.joined-dict.transformer/model.pt \
--output_dir ende_ctranslate2
```

*d. Marian*

```bash
wget https://object.pouta.csc.fi/OPUS-MT-models/en-de/opus-2020-02-26.zip
unzip opus-2020-02-26.zip

ct2-marian-converter --model_path opus.spm32k-spm32k.transformer-align.model1.npz.best-perplexity.npz \
--vocab_paths opus.spm32k-spm32k.vocab.yml opus.spm32k-spm32k.vocab.yml \
--output_dir ende_ctranslate2

# For OPUS-MT models, you can use ct2-opus-mt-converter instead:
ct2-opus-mt-converter --model_dir . --output_dir ende_ctranslate2
```
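
For reference, the same conversions can also be run programmatically through the Python conversion API added by this commit. Below is a minimal sketch, assuming the files extracted from the OPUS archive above and the default conversion options:

```python
import ctranslate2

# Marian model: pass the checkpoint and the source/target vocabularies
# (this OPUS-MT en-de model shares a single vocabulary file for both sides).
converter = ctranslate2.converters.MarianConverter(
    model_path="opus.spm32k-spm32k.transformer-align.model1.npz.best-perplexity.npz",
    vocab_paths=[
        "opus.spm32k-spm32k.vocab.yml",
        "opus.spm32k-spm32k.vocab.yml",
    ],
)
converter.convert("ende_ctranslate2", force=True)

# OPUS-MT model: the converter reads the model files from the given directory.
converter = ctranslate2.converters.OpusMTConverter(model_dir=".")
converter.convert("ende_ctranslate2", force=True)
```

Either call produces the same `ende_ctranslate2` directory used in the translation step below.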

**3\. [Translate](#translating) tokenized inputs with the Python API:**

```python
import ctranslate2

translator = ctranslate2.Translator("ende_ctranslate2/", device="cpu")

# The OpenNMT-py and OpenNMT-tf models use a SentencePiece tokenization:
translator.translate_batch([["▁H", "ello", "▁world", "!"]])
batch = [["▁H", "ello", "▁world", "!"]] # OpenNMT model input
# batch = [["H@@", "ello", "world@@", "!"]] # Fairseq model input
# batch = [["▁Hello", "▁world", "!"]] # Marian model input

# The Fairseq model uses a BPE tokenization:
translator.translate_batch([["H@@", "ello", "world@@", "!"]])
translator.translate_batch(batch)
```
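
`translate_batch` returns one result per example in the batch. The exact result type has changed across CTranslate2 versions, so the sketch below only prints the raw result and leaves detokenization (e.g. with SentencePiece for the OPUS-MT model) to the caller:

```python
results = translator.translate_batch(batch)

# One result per input example; the output tokens are still subwords and need
# to be detokenized with the model's tokenizer to obtain plain text.
print(results[0])
```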

## Installation
@@ -163,10 +178,10 @@ The core CTranslate2 implementation is framework agnostic. The framework specifi

The following frameworks and models are currently supported:

| | OpenNMT-tf | OpenNMT-py | Fairseq |
| --- | :---: | :---: | :---: |
| Transformer ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) | ✓ | ✓ | ✓ |
| + relative position representations ([Shaw et al. 2018](https://arxiv.org/abs/1803.02155)) | ✓ | ✓ | |
| | OpenNMT-tf | OpenNMT-py | Fairseq | Marian |
| --- | :---: | :---: | :---: | :---: |
| Transformer ([Vaswani et al. 2017](https://arxiv.org/abs/1706.03762)) | ✓ | ✓ | ✓ | ✓ |
| + relative position representations ([Shaw et al. 2018](https://arxiv.org/abs/1803.02155)) | ✓ | ✓ | | |

*If you are using a model that is not listed above, consider opening an issue to discuss future integration.*

@@ -175,6 +190,8 @@ The Python package includes a [conversion API](docs/python.md#model-conversion-a
* `ct2-opennmt-py-converter`
* `ct2-opennmt-tf-converter`
* `ct2-fairseq-converter`
* `ct2-marian-converter`
* `ct2-opus-mt-converter` (based on `ct2-marian-converter`)

The conversion should be run in the same environment as the selected training framework.

@@ -480,8 +497,10 @@ The implementation has been generously tested in [production environment](https:
* Python symbols:
* `ctranslate2.Translator`
* `ctranslate2.converters.FairseqConverter`
* `ctranslate2.converters.MarianConverter`
* `ctranslate2.converters.OpenNMTPyConverter`
* `ctranslate2.converters.OpenNMTTFConverter`
* `ctranslate2.converters.OpusMTConverter`
* C++ symbols:
* `ctranslate2::models::Model`
* `ctranslate2::TranslationOptions`
9 changes: 9 additions & 0 deletions docs/python.md
@@ -27,6 +27,15 @@ converter = ctranslate2.converters.FairseqConverter(
fixed_dictionary: str = None, # Path to the fixed dictionary for multilingual models.
)

converter = ctranslate2.converters.MarianConverter(
model_path: str, # Path to the Marian model (.npz file).
vocab_paths: List[str], # Paths to the vocabularies (.yml files).
)

converter = ctranslate2.converters.OpusMTConverter(
model_dir: str, # Path to the OPUS-MT model directory.
)

output_dir = converter.convert(
output_dir: str, # Path to the output directory.
vmap: str = None, # Path to a vocabulary mapping file.
2 changes: 2 additions & 0 deletions include/ctranslate2/layers/decoder.h
@@ -11,6 +11,8 @@ namespace ctranslate2 {

using DecoderState = std::unordered_map<std::string, StorageView>;

void zero_first_timestep(StorageView& x, dim_t step);

// Base class for decoders.
class Decoder : public Layer {
public:
1 change: 1 addition & 0 deletions include/ctranslate2/layers/transformer.h
@@ -191,6 +191,7 @@ namespace ctranslate2 {
dim_t _alignment_heads;
const ComputeType _compute_type;
const Embeddings _embeddings;
const bool _start_from_zero_embedding;
const std::unique_ptr<const StorageView> _embeddings_scale;
const std::unique_ptr<PositionEncoder> _position_encoder;
const std::unique_ptr<LayerNorm> _layernorm_embedding;
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/__init__.py
@@ -1,4 +1,6 @@
from ctranslate2.converters.converter import Converter
from ctranslate2.converters.fairseq import FairseqConverter
from ctranslate2.converters.marian import MarianConverter
from ctranslate2.converters.opennmt_py import OpenNMTPyConverter
from ctranslate2.converters.opennmt_tf import OpenNMTTFConverter
from ctranslate2.converters.opus_mt import OpusMTConverter
58 changes: 30 additions & 28 deletions python/ctranslate2/converters/fairseq.py
@@ -32,34 +32,36 @@
def _get_model_spec(args):
activation_fn = getattr(args, "activation_fn", "relu")

reasons = []
if args.arch not in _SUPPORTED_ARCHS:
reasons.append(
"Option --arch %s is not supported (supported architectures are: %s)"
% (args.arch, ", ".join(_SUPPORTED_ARCHS))
)
if args.encoder_normalize_before != args.decoder_normalize_before:
reasons.append(
"Options --encoder-normalize-before and --decoder-normalize-before "
"must have the same value"
)
if args.encoder_attention_heads != args.decoder_attention_heads:
reasons.append(
"Options --encoder-attention-heads and --decoder-attention-heads must "
"have the same value"
)
if activation_fn not in _SUPPORTED_ACTIVATIONS.keys():
reasons.append(
"Option --activation-fn %s is not supported (supported activations are: %s)"
% (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys()))
)
if getattr(args, "no_token_positional_embeddings", False):
reasons.append("Option --no-token-positional-embeddings is not supported")
if getattr(args, "lang_tok_replacing_bos_eos", False):
reasons.append("Option --lang-tok-replacing-bos-eos is not supported")

if reasons:
utils.raise_unsupported(reasons)
check = utils.ConfigurationChecker()
check(
args.arch in _SUPPORTED_ARCHS,
"Option --arch %s is not supported (supported architectures are: %s)"
% (args.arch, ", ".join(_SUPPORTED_ARCHS)),
)
check(
args.encoder_normalize_before == args.decoder_normalize_before,
"Options --encoder-normalize-before and --decoder-normalize-before "
"must have the same value",
)
check(
args.encoder_attention_heads == args.decoder_attention_heads,
"Options --encoder-attention-heads and --decoder-attention-heads "
"must have the same value",
)
check(
activation_fn in _SUPPORTED_ACTIVATIONS,
"Option --activation-fn %s is not supported (supported activations are: %s)"
% (activation_fn, ", ".join(_SUPPORTED_ACTIVATIONS.keys())),
)
check(
not getattr(args, "no_token_positional_embeddings", False),
"Option --no-token-positional-embeddings is not supported",
)
check(
not getattr(args, "lang_tok_replacing_bos_eos", False),
"Option --lang-tok-replacing-bos-eos is not supported",
)
check.validate()

return transformer_spec.TransformerSpec(
(args.encoder_layers, args.decoder_layers),
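
Note on the fairseq.py change above: the refactored checks go through a `ConfigurationChecker` helper from the converters' `utils` module, which is not part of the hunks shown in this excerpt. A minimal sketch of the interface the new code assumes — a callable that records a reason when a condition fails, and a `validate()` that raises once with all collected reasons — might look like this (illustrative only, not the actual implementation):

```python
class ConfigurationChecker:
    """Collects the reasons why a model configuration is unsupported."""

    def __init__(self):
        self._errors = []

    def __call__(self, condition, error_message):
        # Record the message when the asserted condition does not hold.
        if not condition:
            self._errors.append(error_message)

    def validate(self):
        # Raise a single error listing every unsupported option, if any.
        if self._errors:
            raise ValueError(
                "The model you are trying to convert is not supported:\n- "
                + "\n- ".join(self._errors)
            )
```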