
Commit

fix encoder/decoder errors
vince62s committed Dec 17, 2024
1 parent 8833112 commit 64b9803
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions python/ctranslate2/converters/eole_ct2.py
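
For reference, converters in this package are normally driven through the common Converter interface rather than imported directly. A minimal usage sketch, assuming this module exposes an EoleConverter class following the same pattern as the other converters in ctranslate2.converters (the class name and checkpoint path below are assumptions, not taken from this commit):

# Hedged usage sketch: EoleConverter and the checkpoint path are assumed, not confirmed by this commit.
import ctranslate2

converter = ctranslate2.converters.EoleConverter("model_step_10000")  # hypothetical eole checkpoint
converter.convert("ende_ct2", quantization="int8", force=True)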
@@ -29,10 +29,16 @@ def _get_model_spec_seq2seq(
         getattr(config.embeddings, "position_encoding_type", None)
         == PositionEncodingType.Rotary
     )
+    if with_rotary:
+        raise ValueError(
+            "Rotary embeddings are not supported yet for encoder/decoder models"
+        )
     with_alibi = (
         getattr(config.embeddings, "position_encoding_type", None)
         == PositionEncodingType.Alibi
     )
+    if with_alibi:
+        raise ValueError("Alibi is not supported yet for encoder/decoder models")
     activation_fn = getattr(config, "mlp_activation_fn", "relu")
 
     # Return the first head of the last layer unless the model was trained with alignments.
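
For context, the two guards added above follow one pattern: read an optional attribute from the embeddings config with getattr and fail fast on position-encoding variants that the encoder/decoder spec cannot represent yet. A standalone sketch of that pattern, with a simplified stand-in for the eole config and enum (the names below are illustrative, not the real eole classes):

# Simplified stand-in for the eole enum and config, for illustration only.
from enum import Enum
from types import SimpleNamespace


class PositionEncodingType(Enum):
    SinusoidalInterleaved = "SinusoidalInterleaved"
    Rotary = "Rotary"
    Alibi = "Alibi"


def check_position_encoding(embeddings):
    """Fail fast on encodings the seq2seq spec does not handle yet."""
    pe_type = getattr(embeddings, "position_encoding_type", None)
    if pe_type == PositionEncodingType.Rotary:
        raise ValueError(
            "Rotary embeddings are not supported yet for encoder/decoder models"
        )
    if pe_type == PositionEncodingType.Alibi:
        raise ValueError("Alibi is not supported yet for encoder/decoder models")


# A missing attribute or a sinusoidal encoding passes through silently.
check_position_encoding(SimpleNamespace())
check_position_encoding(
    SimpleNamespace(position_encoding_type=PositionEncodingType.SinusoidalInterleaved)
)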
@@ -44,26 +50,30 @@ def _get_model_spec_seq2seq(
         alignment_heads = config.decoder.alignment_heads
 
     num_heads = getattr(config.decoder, "heads", 8)
-    num_kv = getattr(config.decoder, "heads_kv", 0)
-    if num_kv == num_heads or num_kv == 0:
-        num_kv = None
-    rotary_dim = 0 if with_rotary else None
-    rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
+    # num_kv = getattr(config.decoder, "heads_kv", 0)
+    # if num_kv == num_heads or num_kv == 0:
+    #     num_kv = None
+    # rotary_dim = 0 if with_rotary else None
+    # rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
     ffn_glu = activation_fn == "gated-silu"
     sliding_window = getattr(config, "sliding_window", 0)
+    if sliding_window != 0:
+        raise ValueError(
+            "Sliding window is not supported yet for encoder/decoder models"
+        )
 
     model_spec = transformer_spec.TransformerSpec.from_config(
         (config.encoder.layers, config.decoder.layers),
         num_heads,
         with_relative_position=with_relative_position,
-        alibi=with_alibi,
+        # alibi=with_alibi,
         activation=_SUPPORTED_ACTIVATIONS[activation_fn],
         ffn_glu=ffn_glu,
         rms_norm=config.layer_norm == "rms",
-        rotary_dim=rotary_dim,
-        rotary_interleave=rotary_interleave,
-        num_heads_kv=num_kv,
-        sliding_window=sliding_window,
+        # rotary_dim=rotary_dim,
+        # rotary_interleave=rotary_interleave,
+        # num_heads_kv=num_kv,
+        # sliding_window=sliding_window,
         alignment_layer=alignment_layer,
         alignment_heads=alignment_heads,
         num_source_embeddings=num_source_embeddings,
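
The commented-out block above implemented a small normalization of the KV-head count before it was passed to TransformerSpec.from_config: a value equal to the number of attention heads (plain multi-head attention) or 0 (unset) collapses to None so that no separate KV-head count is recorded in the spec. A standalone sketch of that rule (the helper name is mine, not the converter's):

def normalize_heads_kv(num_heads, num_kv):
    """Collapse 'same as heads' or 'unset' to None (plain multi-head attention)."""
    if num_kv == num_heads or num_kv == 0:
        return None
    return num_kv


assert normalize_heads_kv(8, 0) is None  # heads_kv not set in the config
assert normalize_heads_kv(8, 8) is None  # standard MHA, nothing extra to record
assert normalize_heads_kv(8, 2) == 2  # grouped-query attention keeps its KV count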
@@ -202,7 +212,7 @@ def set_transformer_spec(spec, variables):
 
 
 def set_transformer_encoder(spec, variables):
-    set_input_layers(spec, variables, "encoder")
+    set_input_layers(spec, variables, "src_emb")
     set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
     for i, layer in enumerate(spec.layer):
         set_transformer_encoder_layer(layer, variables, "encoder.transformer.%d" % i)
@@ -227,7 +237,7 @@ def set_input_layers(spec, variables, scope):
         set_position_encodings(
             spec.position_encodings,
             variables,
-            "%s.embeddings.pe" % scope,
+            "%s.pe" % scope,
         )
     else:
         spec.scale_embeddings = False
@@ -236,14 +246,19 @@ def set_input_layers(spec, variables, scope):
 
 
 def set_transformer_encoder_layer(spec, variables, scope):
-    set_ffn(spec.ffn, variables, "%s.feed_forward" % scope)
     set_multi_head_attention(
         spec.self_attention,
         variables,
         "%s.self_attn" % scope,
         self_attention=True,
     )
-    set_layer_norm(spec.self_attention.layer_norm, variables, "%s.layer_norm" % scope)
+    set_layer_norm(
+        spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
+    )
+    set_layer_norm(
+        spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
+    )
+    set_ffn(spec.ffn, variables, "%s.mlp" % scope)
 
 
 def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
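
The encoder-layer change above does not alter the spec layout; it only remaps which checkpoint scopes each spec field is read from (input_layernorm, post_attention_layernorm, and mlp instead of the old layer_norm and feed_forward names). As a rough illustration of what such a setter resolves to, a hedged sketch (the real set_layer_norm in this converter may differ):

def set_layer_norm_sketch(spec, variables, scope):
    # Illustrative only: copy the layer-norm weights for one checkpoint scope
    # into the corresponding spec fields.
    spec.gamma = variables["%s.weight" % scope]
    # RMSNorm checkpoints typically carry no bias term, so it stays optional.
    bias_key = "%s.bias" % scope
    if bias_key in variables:
        spec.beta = variables[bias_key]

With a lookup like this, renaming the scope from "%s.layer_norm" to "%s.input_layernorm" simply points the same spec field at eole's new checkpoint keys.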
@@ -258,7 +273,9 @@ def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention
     )
     if with_encoder_attention:
         set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
-        set_layer_norm(spec.attention.layer_norm, variables, "%s.layer_norm_2" % scope)
+        set_layer_norm(
+            spec.attention.layer_norm, variables, "%s.precontext_layernorm" % scope
+        )
     set_layer_norm(
         spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
     )
