Skip to content

Commit 4f8a4f3

Browse files
authored
fix onmt converter (#1581)
1 parent 83caf67 commit 4f8a4f3

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/ctranslate2/converters/opennmt_py.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
104104
activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu")
105105
num_heads = getattr(opt, "heads", 8)
106106
num_kv = getattr(opt, "num_kv", 0)
107-
if num_kv == num_heads:
107+
if num_kv == num_heads or num_kv == 0:
108108
num_kv = None
109109
rotary_dim = 0 if with_rotary else None
110+
rotary_interleave = getattr(opt, "rotary_interleave", True)
110111
ffn_glu = activation_fn == "silu"
111112
sliding_window = getattr(opt, "sliding_window", 0)
112113

@@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
119120
alibi=with_alibi,
120121
rms_norm=opt.layer_norm == "rms",
121122
rotary_dim=rotary_dim,
122-
rotary_interleave=True,
123+
rotary_interleave=rotary_interleave,
123124
multi_query_attention=getattr(opt, "multiquery", False),
124125
num_heads_kv=num_kv,
125126
sliding_window=sliding_window,
@@ -329,7 +330,7 @@ def set_linear(spec, variables, scope):
329330
spec.weight = _get_variable(variables, "%s.weight" % scope)
330331
bias = variables.get("%s.bias" % scope)
331332
if bias is not None:
332-
spec.bias = bias.numpy()
333+
spec.bias = bias
333334

334335

335336
def set_embeddings(spec, variables, scope):
@@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope):
341342

342343

343344
def _get_variable(variables, name):
344-
return variables[name].numpy()
345+
return variables[name]
345346

346347

347348
def main():

0 commit comments

Comments
 (0)