@@ -130,6 +130,21 @@ def test_opennmt_tf_model_conversion(tmpdir, model_path, src_vocab, tgt_vocab, m
130
130
output = translator .translate_batch ([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن" ]])
131
131
assert output [0 ][0 ]["tokens" ] == ["a" , "t" , "z" , "m" , "o" , "n" ]
132
132
133
+ @pytest .mark .skipif (not _FRAMEWORK_DATA_EXIST , reason = "Data files are not available" )
134
+ @pytest .mark .parametrize ("quantization" , ["int16" , "int8" ])
135
+ def test_opennmt_tf_model_quantization (tmpdir , quantization ):
136
+ model_path = os .path .join (
137
+ _TEST_DATA_DIR , "models" , "transliteration-aren-all" , "opennmt_tf" , "v2" , "checkpoint" )
138
+ converter = ctranslate2 .converters .OpenNMTTFConverter (
139
+ model_path ,
140
+ src_vocab = os .path .join (model_path , "ar.vocab" ),
141
+ tgt_vocab = os .path .join (model_path , "en.vocab" ))
142
+ output_dir = str (tmpdir .join ("ctranslate2_model" ))
143
+ converter .convert (output_dir , ctranslate2 .specs .TransformerBase (), quantization = quantization )
144
+ translator = ctranslate2 .Translator (output_dir )
145
+ output = translator .translate_batch ([["آ" ,"ت" ,"ز" ,"م" ,"و" ,"ن" ]])
146
+ assert output [0 ][0 ]["tokens" ] == ["a" , "t" , "z" , "m" , "o" , "n" ]
147
+
133
148
@pytest .mark .skipif (not _FRAMEWORK_DATA_EXIST , reason = "Data files are not available" )
134
149
def test_opennmt_tf_variables_conversion (tmpdir ):
135
150
model_path = os .path .join (
@@ -159,6 +174,43 @@ def test_opennmt_tf_model_conversion_invalid_vocab(tmpdir):
159
174
with pytest .raises (ValueError ):
160
175
converter .convert (output_dir , ctranslate2 .specs .TransformerBase ())
161
176
177
+ def test_opennmt_tf_shared_embeddings_conversion (tmpdir ):
178
+ # Issue https://github.com/OpenNMT/CTranslate2/issues/118
179
+ import tensorflow as tf
180
+ import opennmt
181
+
182
+ vocab = opennmt .data .Vocab ()
183
+ for i in range (10 ):
184
+ vocab .add (str (i ))
185
+ vocab_path = str (tmpdir .join ("vocab.txt" ))
186
+ vocab .serialize (vocab_path )
187
+
188
+ num_layers = 3
189
+ num_heads = 4
190
+ model = opennmt .models .Transformer (
191
+ opennmt .inputters .WordEmbedder (32 ),
192
+ opennmt .inputters .WordEmbedder (32 ),
193
+ num_layers ,
194
+ num_units = 32 ,
195
+ num_heads = num_heads ,
196
+ ffn_inner_dim = 64 ,
197
+ share_embeddings = opennmt .models .EmbeddingsSharingLevel .ALL )
198
+ model .initialize ({"source_vocabulary" : vocab_path , "target_vocabulary" : vocab_path })
199
+ model .create_variables ()
200
+
201
+ checkpoint_prefix = str (tmpdir .join ("ckpt" ))
202
+ checkpoint = tf .train .Checkpoint (model = model )
203
+ checkpoint .write (checkpoint_prefix )
204
+
205
+ converter = ctranslate2 .converters .OpenNMTTFConverter (
206
+ model_path = checkpoint_prefix , src_vocab = vocab_path , tgt_vocab = vocab_path )
207
+ output_dir = str (tmpdir .join ("ctranslate2_model" ))
208
+ converter .convert (output_dir , ctranslate2 .specs .TransformerSpec (num_layers , num_heads ))
209
+
210
+ # Check that the translation runs.
211
+ translator = ctranslate2 .Translator (output_dir )
212
+ translator .translate_batch ([["1" , "2" , "3" ]], max_decoding_length = 10 )
213
+
162
214
@pytest .mark .skipif (not _FRAMEWORK_DATA_EXIST , reason = "Data files are not available" )
163
215
def test_opennmt_py_model_conversion (tmpdir ):
164
216
model_path = os .path .join (
@@ -203,12 +255,29 @@ def __init__(self):
203
255
spec = Spec ()
204
256
spec .validate ()
205
257
assert spec .a .dtype == np .float32
206
- assert spec .b == "a"
258
+ assert spec .b . dtype == np . float32
207
259
assert spec .c .dtype == np .int32
208
260
assert spec .d == OPTIONAL
209
261
assert spec .e .a .dtype == np .float32
210
262
assert spec .f .dtype == np .int8
211
263
264
+ def test_layer_spec_optimize ():
265
+
266
+ class Spec (ctranslate2 .specs .LayerSpec ):
267
+ def __init__ (self ):
268
+ self .a = np .ones ([5 ], dtype = np .float32 )
269
+ self .b = np .ones ([5 ], dtype = np .float32 )
270
+ self .c = np .zeros ([5 ], dtype = np .int32 )
271
+ self .weight = np .ones ([5 , 4 ], dtype = np .float32 )
272
+
273
+ spec = Spec ()
274
+ spec .optimize (quantization = "int16" )
275
+ assert spec .a .dtype == np .float32
276
+ assert spec .b == "a"
277
+ assert spec .c .dtype == np .int32
278
+ assert spec .weight .dtype == np .int16
279
+ assert spec .weight_scale .dtype == np .float32
280
+
212
281
def test_index_spec ():
213
282
spec = ctranslate2 .specs .TransformerBase ()
214
283
assert isinstance (
0 commit comments