Wav2Vec2Bert ASR support
hkwon committed Sep 10, 2024
1 parent 14e4c4c commit f79a752
Showing 24 changed files with 325 additions and 112 deletions.
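With the new "Wav2Vec2BertConfig" loader registered below, a Wav2Vec2BertForCTC checkpoint can be converted through the existing Transformers converter. A minimal conversion sketch (the checkpoint name is a placeholder and the quantization choice is only illustrative):

import ctranslate2

# "some-org/w2v-bert-ctc" is a hypothetical Wav2Vec2BertForCTC fine-tune
converter = ctranslate2.converters.TransformersConverter("some-org/w2v-bert-ctc")
converter.convert("wav2vec2bert_ct2", quantization="float16")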
3 changes: 3 additions & 0 deletions CMakeLists.txt
@@ -127,6 +127,7 @@ set(SOURCES
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
src/logging.cc
src/models/language_model.cc
@@ -136,6 +137,7 @@ set(SOURCES
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
src/ops/activation.cc
src/ops/add.cc
@@ -182,6 +184,7 @@ set(SOURCES
src/ops/split.cc
src/ops/slide.cc
src/ops/sub.cc
src/ops/sigmoid.cc
src/ops/swish.cc
src/ops/tanh.cc
src/ops/tile.cc
14 changes: 14 additions & 0 deletions include/ctranslate2/layers/attention.h
@@ -7,9 +7,20 @@ namespace ctranslate2 {
namespace layers {

StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position,
dim_t left_max_position,
dim_t right_max_position);

/*StorageView make_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t max_position);
StorageView make_relative_asymmetric_positions(dim_t queries_length,
dim_t keys_length,
dim_t left_max_position,
dim_t right_max_position);*/

class RotaryEmbeddings;
class Alibi;

@@ -53,8 +64,11 @@ namespace ctranslate2 {
dim_t beam_size = 1);
const StorageView* _relative_attention_bias;
const StorageView* _relative_position_keys;
const StorageView* _relative_asymmetric_position_keys;
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
dim_t _relative_left_max_position;
dim_t _relative_right_max_position;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
};
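The asymmetric variant of make_relative_positions takes separate left and right maximum distances, matching the relative position embeddings used by the Wav2Vec2-BERT Conformer self-attention. A rough Python sketch of the indexing this implies, assuming distances are clamped to [-left_max_position, right_max_position] and then shifted to index a table of left + right + 1 embeddings (this mirrors the Hugging Face reference behavior, not the literal C++ implementation):

import numpy as np

def relative_position_indices(queries_length, keys_length, left_max, right_max):
    # signed distance from each query position to each key position
    distance = np.arange(keys_length)[None, :] - np.arange(queries_length)[:, None]
    # clamp asymmetrically, then shift into [0, left_max + right_max]
    distance = np.clip(distance, -left_max, right_max)
    return distance + left_max  # indices into the distance embedding table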
1 change: 1 addition & 0 deletions include/ctranslate2/ops/activation.h
@@ -13,6 +13,7 @@ namespace ctranslate2 {
GELU,
GELUSigmoid,
Tanh,
Sigmoid,
};

const UnaryOp& get_activation_op(ActivationType type);
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
@@ -22,6 +22,7 @@
#include "split.h"
#include "squeeze.h"
#include "sub.h"
#include "sigmoid.h"
#include "swish.h"
#include "tile.h"
#include "topk.h"
2 changes: 2 additions & 0 deletions include/ctranslate2/primitives.h
@@ -181,6 +181,8 @@ namespace ctranslate2 {
template <typename T>
static void gelu_sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void swish(const T* x, T* y, dim_t size);

static void compute_u8_compensation(const int8_t* b,
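The new Sigmoid activation and the matching sigmoid primitive are the standard logistic function. A reference sketch of what the elementwise kernel is expected to compute, written in a numerically stable form (illustration only, not the CPU/CUDA implementation):

import numpy as np

def sigmoid(x):
    # 1 / (1 + exp(-x)), avoiding overflow for large-magnitude inputs
    out = np.empty_like(x)
    pos = x >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))
    ex = np.exp(x[~pos])
    out[~pos] = ex / (1.0 + ex)
    return out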
1 change: 1 addition & 0 deletions python/cpp/module.cc
@@ -87,5 +87,6 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
@@ -18,6 +18,7 @@ namespace ctranslate2 {
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);

}
204 changes: 100 additions & 104 deletions python/ctranslate2/converters/transformers.py
@@ -23,6 +23,7 @@
    model_spec,
    transformer_spec,
    wav2vec2_spec,
    wav2vec2bert_spec,
    whisper_spec,
)

@@ -1059,6 +1060,105 @@ def set_common_layers(self, spec, module):
        self.set_layer_norm(spec.layer_norm, module.layer_norm)


@register_loader("Wav2Vec2BertConfig")
class Wav2Vec2BertLoader(BartLoader):
@property
def architecture_name(self):
return "Wav2Vec2BertForCTC"

def get_model_spec(self, model):
spec = wav2vec2bert_spec.Wav2Vec2BertSpec(
model.wav2vec2_bert.config.num_adapter_layers,
model.wav2vec2_bert.config.num_hidden_layers,
)

self.set_encoder(spec.encoder, model)
return spec

def set_config(self, config, model, tokenizer):
return

def get_vocabulary(self, model, tokenizer):
return tokenizer.get_vocab()

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_feature_projection(self, spec, feature_projection):
self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
self.set_linear(spec.fp_projection, feature_projection.projection)

def set_attention(self, spec, attention, left_max_position=None, right_max_position=None):
split_layers = [common_spec.LinearSpec() for _ in range(3)]
self.set_linear(split_layers[0], attention.linear_q)
self.set_linear(split_layers[1], attention.linear_k)
self.set_linear(split_layers[2], attention.linear_v)
utils.fuse_linear(spec.linear[0], split_layers)
self.set_linear(spec.linear[-1], attention.linear_out)
if left_max_position or right_max_position:
spec.relative_asymmetric_position_keys = attention.distance_embedding.weight
spec.relative_left_max_position = np.dtype("int32").type(left_max_position)
spec.relative_right_max_position = np.dtype("int32").type(right_max_position)

def set_wav2vec2bert_encoder(self, spec_layers, layers, left_max_position, right_max_position):
for slayer, layer in zip(spec_layers, layers):
self.set_layer_norm(slayer.enc_ffn1_layer_norm, layer.ffn1_layer_norm)
self.set_linear(slayer.enc_ffn1.linear_0, layer.ffn1.intermediate_dense)
self.set_linear(slayer.enc_ffn1.linear_1, layer.ffn1.output_dense)
self.set_attention(slayer.enc_attn, layer.self_attn, left_max_position, right_max_position)
self.set_layer_norm(slayer.enc_attn_layer_norm, layer.self_attn_layer_norm)
self.set_layer_norm(
slayer.enc_conv_layer_norm, layer.conv_module.layer_norm
)
self.set_conv1d(slayer.enc_conv_pointwise_conv1, layer.conv_module.pointwise_conv1)
self.set_conv1d(slayer.enc_conv_depthwise_conv, layer.conv_module.depthwise_conv)
self.set_layer_norm(
slayer.enc_conv_depthwise_layer_norm,
layer.conv_module.depthwise_layer_norm,
)
self.set_conv1d(slayer.enc_conv_pointwise_conv2, layer.conv_module.pointwise_conv2)
self.set_layer_norm(slayer.enc_ffn2_layer_norm, layer.ffn2_layer_norm)
self.set_linear(slayer.enc_ffn2.linear_0, layer.ffn2.intermediate_dense)
self.set_linear(slayer.enc_ffn2.linear_1, layer.ffn2.output_dense)
self.set_layer_norm(slayer.enc_final_layer_norm, layer.final_layer_norm)

def set_wav2vec2bert_adapter(self, spec_layers, layers):
for slayer, layer in zip(spec_layers, layers):
self.set_layer_norm(
slayer.adpt_residual_layer_norm, layer.residual_layer_norm
)
self.set_conv1d(slayer.adpt_residual_conv, layer.residual_conv)
self.set_layer_norm(slayer.adpt_attn_layer_norm, layer.self_attn_layer_norm)
self.set_conv1d(slayer.adpt_attn_conv, layer.self_attn_conv)
self.set_attention(slayer.adpt_attn_layer, layer.self_attn)
self.set_layer_norm(slayer.adpt_ffn_layer_norm, layer.ffn_layer_norm)
self.set_linear(slayer.adpt_ffn.linear_0, layer.ffn.intermediate_dense)
self.set_linear(slayer.adpt_ffn.linear_1, layer.ffn.output_dense)

def set_encoder(self, spec, model):
self.set_feature_projection(spec, model.wav2vec2_bert.feature_projection)
self.set_wav2vec2bert_encoder(
spec.encoder_layers,
model.wav2vec2_bert.encoder.layers,
model.wav2vec2_bert.config.left_max_position_embeddings,
model.wav2vec2_bert.config.right_max_position_embeddings,
)
self.set_wav2vec2bert_adapter(
spec.adapter_layers, model.wav2vec2_bert.adapter.layers
)
self.set_linear(spec.lm_head, model.lm_head)

def set_conv1d(self, spec, module):
spec.weight = module.weight
if module.bias is not None:
spec.bias = module.bias

def set_layer_norm(self, spec, module):
spec.gamma = module.weight
if module.bias is not None:
spec.beta = module.bias
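
In set_attention above, the separate linear_q/linear_k/linear_v projections are loaded into temporary specs and merged with utils.fuse_linear into the first linear of the attention spec. A simplified sketch of that fusion, assuming it amounts to concatenating the projections along the output dimension so a single matrix multiplication yields Q, K and V:

import numpy as np

def fuse_qkv(w_q, w_k, w_v, b_q=None, b_k=None, b_v=None):
    # each weight has shape (d_out, d_in); stacking gives (3 * d_out, d_in)
    weight = np.concatenate([w_q, w_k, w_v], axis=0)
    bias = np.concatenate([b_q, b_k, b_v], axis=0) if b_q is not None else None
    return weight, bias

# at inference time the single projection output is split back into three chunks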


@register_loader("T5Config")
class T5Loader(ModelLoader):
    @property
@@ -1421,110 +1521,6 @@ def set_decoder(self, spec, module):
        gc.collect()


@register_loader("Gemma2Config")
class Gemma2Loader(ModelLoader):
@property
def architecture_name(self):
return "Gemma2ForCausalLM"

def get_model_spec(self, model):
num_layers = model.config.num_hidden_layers

num_heads = model.config.num_attention_heads
num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
if num_heads_kv == num_heads:
num_heads_kv = None

activation_config = getattr(
model.config, "hidden_activation", "gelu_pytorch_tanh"
)

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=(
common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh
),
pre_norm=True,
ffn_glu=True,
rms_norm=True,
rotary_dim=0,
rotary_interleave=False,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
head_dim=model.config.head_dim,
pre_post_layer_norm=True,
)

self.set_decoder(spec.decoder, model.model)
self.set_linear(spec.decoder.projection, model.lm_head)
spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)
if model.config.vocab_size < len(tokens):
tokens = tokens[: model.config.vocab_size]

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token
config.layer_norm_epsilon = model.config.rms_norm_eps

def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
spec.layer_norm_use_residual = True

def set_decoder(self, spec, module):
spec.scale_embeddings = True
spec.start_from_zero_embedding = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)

for layer_spec, layer in zip(spec.layer, module.layers):
self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)

self.set_layer_norm(
layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
)

self.set_layer_norm(
layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
)

self.set_layer_norm(
layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
)

wq = layer.self_attn.q_proj.weight
wk = layer.self_attn.k_proj.weight
wv = layer.self_attn.v_proj.weight
wo = layer.self_attn.o_proj.weight

layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
layer_spec.self_attention.linear[1].weight = wo

self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)

delattr(layer, "self_attn")
delattr(layer, "mlp")
gc.collect()


@register_loader("LlamaConfig")
class LlamaLoader(ModelLoader):
    @property
1 change: 1 addition & 0 deletions python/ctranslate2/models/__init__.py
@@ -5,6 +5,7 @@
try:
    from ctranslate2._ext import (
        Wav2Vec2,
        Wav2Vec2Bert,
        Whisper,
        WhisperGenerationResult,
        WhisperGenerationResultAsync,
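Once converted, the model is exposed on the Python side as ctranslate2.models.Wav2Vec2Bert. A usage sketch under the assumption that it mirrors the existing Wav2Vec2 interface, i.e. an encode() method that takes pre-computed input features as a StorageView and returns CTC logits; the method name, the feature shape, and the feature-extraction step are assumptions, not confirmed by this diff:

import ctranslate2
import numpy as np

model = ctranslate2.models.Wav2Vec2Bert("wav2vec2bert_ct2", device="cpu")

# placeholder for features produced by the Hugging Face feature extractor for this model
features = np.zeros((1, 200, 160), dtype=np.float32)
logits = model.encode(ctranslate2.StorageView.from_array(features))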
1 change: 1 addition & 0 deletions python/ctranslate2/specs/__init__.py
@@ -14,4 +14,5 @@
    TransformerSpec,
)
from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec
from ctranslate2.specs.whisper_spec import WhisperSpec
6 changes: 6 additions & 0 deletions python/ctranslate2/specs/attention_spec.py
@@ -19,6 +19,7 @@ def __init__(
        self,
        self_attention=False,
        relative_position=False,
        relative_asymmetric_position=False,
        relative_attention_bias=False,
        rms_norm=False,
        rotary_dim=None,
@@ -47,6 +48,11 @@ def __init__(
            self.relative_attention_bias = None
            self.relative_attention_max_distance = None

        if relative_asymmetric_position:
            self.relative_asymmetric_position_keys = None
            self.relative_left_max_position = None
            self.relative_right_max_position = None

        if original_max_position_embeddings != 0:
            self.original_max_position_embeddings = np.dtype("int32").type(
                original_max_position_embeddings
1 change: 1 addition & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -13,6 +13,7 @@ class Activation(enum.IntEnum):
    GELU = 3
    GELUSigmoid = 4
    Tanh = 5
    Sigmoid = 6


# This enum should match the C++ equivalent in include/ctranslate2/layers/common.h.