Skip to content

Commit

Permalink
Merge branch 'master' into n_mels_param
Browse files Browse the repository at this point in the history
  • Loading branch information
Valentin Berkes committed Nov 7, 2023
2 parents 261687c + d0a9227 commit 3836555
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,8 @@ def get_model_spec(self, model):
spec = whisper_spec.WhisperSpec(
model.config.encoder_layers,
model.config.encoder_attention_heads,
model.config.decoder_layers,
model.config.decoder_attention_heads,
)

self.set_encoder(spec.encoder, model.model.encoder)
Expand Down
20 changes: 15 additions & 5 deletions python/ctranslate2/specs/whisper_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,27 @@ def __init__(
class WhisperSpec(model_spec.LanguageModelSpec):
"""Describes a Whisper model."""

def __init__(self, num_layers, num_heads):
def __init__(
self,
num_encoder_layers,
num_encoder_heads,
num_decoder_layers,
num_decoder_heads,
):
"""Initializes the model specification.
Args:
num_layers: The number of encoder and decoder layers.
num_heads: The number of attention heads.
num_encoder_layers: The number of encoder layers.
num_encoder_heads: The number of encoder attention heads.
num_decoder_layers: The number of decoder layers.
num_decoder_heads: The number of decoder attention heads.
"""
super().__init__()
self.encoder = WhisperEncoderSpec(num_layers, num_heads)
self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
self.decoder = transformer_spec.TransformerDecoderSpec(
num_layers, num_heads, activation=common_spec.Activation.GELU
num_decoder_layers,
num_decoder_heads,
activation=common_spec.Activation.GELU,
)
self.decoder.scale_embeddings = False

Expand Down

0 comments on commit 3836555

Please sign in to comment.