Skip to content

Commit bb1cfc1

Browse files
committed
Fix conversion of models with shared embeddings (#120)
Variables should be aliased after checking the vocabulary sizes.
1 parent cd14ed3 commit bb1cfc1

File tree

3 files changed: +78 −5 lines changed

python/ctranslate2/converters/converter.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,7 @@ def convert(self, output_dir, model_spec, vmap=None, quantization=None, force=Fa
5959
model_spec.validate()
6060
self._check_vocabulary_size("source", src_vocab, model_spec.source_vocabulary_size)
6161
self._check_vocabulary_size("target", tgt_vocab, model_spec.target_vocabulary_size)
62-
if quantization is not None:
63-
model_spec.quantize(quantization)
62+
model_spec.optimize(quantization=quantization)
6463
model_spec.serialize(os.path.join(output_dir, "model.bin"))
6564
if vmap is not None:
6665
shutil.copy(vmap, os.path.join(output_dir, "vmap.txt"))

python/ctranslate2/specs/model_spec.py

+7-2
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,6 @@ def _check(spec, name, value):
6464
# Convert bool to an integer type.
6565
setattr(spec, attr_name, np.dtype("int8").type(value))
6666
self.visit(_check)
67-
self._alias_variables()
6867

6968
def variables(self, prefix="", ordered=False):
7069
"""Returns a dict mapping variables name to value. If ordered is True,
@@ -98,7 +97,7 @@ def _alias_variables(self):
9897
setattr(spec, attr_name, other_name)
9998
break
10099

101-
def quantize(self, quantization):
100+
def _quantize(self, quantization):
102101
"""Possibly quantizes the variable of the layer."""
103102
def _quantize(spec, name, value):
104103
if "weight" in name and isinstance(value, np.ndarray):
@@ -117,6 +116,12 @@ def _quantize(spec, name, value):
117116
setattr(spec, "weight", value)
118117
self.visit(_quantize)
119118

119+
def optimize(self, quantization=None):
120+
"""Applies some optimizations on this layer."""
121+
self._alias_variables()
122+
if quantization is not None:
123+
self._quantize(quantization)
124+
120125
def visit(self, fn):
121126
"""Recursively visits this layer and its children."""
122127
visit_spec(self, fn)

python/tests/test.py

+70-1
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,21 @@ def test_opennmt_tf_model_conversion(tmpdir, model_path, src_vocab, tgt_vocab, m
130130
output = translator.translate_batch([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"]])
131131
assert output[0][0]["tokens"] == ["a", "t", "z", "m", "o", "n"]
132132

133+
@pytest.mark.skipif(not _FRAMEWORK_DATA_EXIST, reason="Data files are not available")
134+
@pytest.mark.parametrize("quantization", ["int16", "int8"])
135+
def test_opennmt_tf_model_quantization(tmpdir, quantization):
136+
model_path = os.path.join(
137+
_TEST_DATA_DIR, "models", "transliteration-aren-all", "opennmt_tf", "v2", "checkpoint")
138+
converter = ctranslate2.converters.OpenNMTTFConverter(
139+
model_path,
140+
src_vocab=os.path.join(model_path, "ar.vocab"),
141+
tgt_vocab=os.path.join(model_path, "en.vocab"))
142+
output_dir = str(tmpdir.join("ctranslate2_model"))
143+
converter.convert(output_dir, ctranslate2.specs.TransformerBase(), quantization=quantization)
144+
translator = ctranslate2.Translator(output_dir)
145+
output = translator.translate_batch([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن"]])
146+
assert output[0][0]["tokens"] == ["a", "t", "z", "m", "o", "n"]
147+
133148
@pytest.mark.skipif(not _FRAMEWORK_DATA_EXIST, reason="Data files are not available")
134149
def test_opennmt_tf_variables_conversion(tmpdir):
135150
model_path = os.path.join(
@@ -159,6 +174,43 @@ def test_opennmt_tf_model_conversion_invalid_vocab(tmpdir):
159174
with pytest.raises(ValueError):
160175
converter.convert(output_dir, ctranslate2.specs.TransformerBase())
161176

177+
def test_opennmt_tf_shared_embeddings_conversion(tmpdir):
178+
# Issue https://github.com/OpenNMT/CTranslate2/issues/118
179+
import tensorflow as tf
180+
import opennmt
181+
182+
vocab = opennmt.data.Vocab()
183+
for i in range(10):
184+
vocab.add(str(i))
185+
vocab_path = str(tmpdir.join("vocab.txt"))
186+
vocab.serialize(vocab_path)
187+
188+
num_layers = 3
189+
num_heads = 4
190+
model = opennmt.models.Transformer(
191+
opennmt.inputters.WordEmbedder(32),
192+
opennmt.inputters.WordEmbedder(32),
193+
num_layers,
194+
num_units=32,
195+
num_heads=num_heads,
196+
ffn_inner_dim=64,
197+
share_embeddings=opennmt.models.EmbeddingsSharingLevel.ALL)
198+
model.initialize({"source_vocabulary": vocab_path, "target_vocabulary": vocab_path})
199+
model.create_variables()
200+
201+
checkpoint_prefix = str(tmpdir.join("ckpt"))
202+
checkpoint = tf.train.Checkpoint(model=model)
203+
checkpoint.write(checkpoint_prefix)
204+
205+
converter = ctranslate2.converters.OpenNMTTFConverter(
206+
model_path=checkpoint_prefix, src_vocab=vocab_path, tgt_vocab=vocab_path)
207+
output_dir = str(tmpdir.join("ctranslate2_model"))
208+
converter.convert(output_dir, ctranslate2.specs.TransformerSpec(num_layers, num_heads))
209+
210+
# Check that the translation runs.
211+
translator = ctranslate2.Translator(output_dir)
212+
translator.translate_batch([["1", "2", "3"]], max_decoding_length=10)
213+
162214
@pytest.mark.skipif(not _FRAMEWORK_DATA_EXIST, reason="Data files are not available")
163215
def test_opennmt_py_model_conversion(tmpdir):
164216
model_path = os.path.join(
@@ -203,12 +255,29 @@ def __init__(self):
203255
spec = Spec()
204256
spec.validate()
205257
assert spec.a.dtype == np.float32
206-
assert spec.b == "a"
258+
assert spec.b.dtype == np.float32
207259
assert spec.c.dtype == np.int32
208260
assert spec.d == OPTIONAL
209261
assert spec.e.a.dtype == np.float32
210262
assert spec.f.dtype == np.int8
211263

264+
def test_layer_spec_optimize():
265+
266+
class Spec(ctranslate2.specs.LayerSpec):
267+
def __init__(self):
268+
self.a = np.ones([5], dtype=np.float32)
269+
self.b = np.ones([5], dtype=np.float32)
270+
self.c = np.zeros([5], dtype=np.int32)
271+
self.weight = np.ones([5, 4], dtype=np.float32)
272+
273+
spec = Spec()
274+
spec.optimize(quantization="int16")
275+
assert spec.a.dtype == np.float32
276+
assert spec.b == "a"
277+
assert spec.c.dtype == np.int32
278+
assert spec.weight.dtype == np.int16
279+
assert spec.weight_scale.dtype == np.float32
280+
212281
def test_index_spec():
213282
spec = ctranslate2.specs.TransformerBase()
214283
assert isinstance(

Comments (0)