Wav2Vec2 upgrade with Conv1D options
hkwon committed Aug 13, 2024
1 parent d202032 commit 2768558
Showing 6 changed files with 230 additions and 116 deletions.
71 changes: 69 additions & 2 deletions include/ctranslate2/layers/wav2vec2.h
@@ -5,6 +5,68 @@
namespace ctranslate2 {
namespace layers {

class Wav2Vec2LayerNormConvLayer0 : public Layer {
public:
Wav2Vec2LayerNormConvLayer0(const models::Model& model, const std::string& scope);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2LayerNormConvLayer : public Layer {
public:
Wav2Vec2LayerNormConvLayer(const models::Model& model, const std::string& scope);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2PosConvLayer : public Layer {
public:
Wav2Vec2PosConvLayer(const models::Model& model, const std::string& scope);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2Encoder : public Layer {
public:
Wav2Vec2Encoder(const models::Model& model, const std::string& scope);
@@ -35,12 +97,17 @@ namespace ctranslate2 {
}

private:
const Wav2Vec2LayerNormConvLayer0 _feat_layer0;
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
const LayerNorm _fp_norm;
const Dense _fp_ff;
const Wav2Vec2PosConvLayer _pos_conv_embed;
const ops::Transpose _transpose;
const ops::GELU _gelu;
-      // wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
-      //const ops::Transpose _transpose;
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
const Dense _lm_head;
};

}
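Note on the new header: Wav2Vec2PosConvLayer wraps the positional convolution whose groups=16 configuration previously blocked full-model conversion (see the comments removed above and in the converter below). As background, a grouped Conv1D is equivalent to running independent convolutions over channel slices. The following PyTorch sketch is illustrative only and not part of this commit; the shapes follow wav2vec2-large's positional conv (1024 channels, kernel 128, groups 16):

import torch
import torch.nn.functional as F

def grouped_conv1d(x, weight, bias, groups, padding):
    # Illustrative helper (not in this commit): run one convolution per group
    # over the matching slice of input channels, then concatenate the outputs.
    # x: (batch, in_channels, time); weight: (out_channels, in_channels // groups, kernel)
    outs = [
        F.conv1d(xg, wg, bg, padding=padding)
        for xg, wg, bg in zip(
            x.chunk(groups, dim=1),
            weight.chunk(groups, dim=0),
            bias.chunk(groups, dim=0),
        )
    ]
    return torch.cat(outs, dim=1)

# Same shapes as wav2vec2-large's pos_conv_embed: 1024 channels, kernel 128, groups 16.
conv = torch.nn.Conv1d(1024, 1024, kernel_size=128, groups=16, padding=64)
x = torch.randn(1, 1024, 200)
assert torch.allclose(grouped_conv1d(x, conv.weight, conv.bias, 16, 64), conv(x), atol=1e-4)

Supporting this grouping in CTranslate2's Conv1D is what lets the converter export pos_conv_embed instead of stopping at the encoder layers.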
39 changes: 32 additions & 7 deletions python/ctranslate2/converters/transformers.py
@@ -992,9 +992,8 @@ def architecture_name(self):
return "Wav2Vec2ForCTC"

    def get_model_spec(self, model):
-        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
-        # that doesn't look available here so we make Wav2Vec2 encoder layers only
        spec = wav2vec2_spec.Wav2Vec2Spec(
+            model.wav2vec2.config.num_feat_extract_layers,
            model.wav2vec2.encoder.config.num_hidden_layers,
            model.wav2vec2.encoder.config.num_attention_heads,
        )
@@ -1007,9 +1006,7 @@ def get_model_spec(self, model):
layer.fc1 = layer.feed_forward.intermediate_dense
layer.fc2 = layer.feed_forward.output_dense

-        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
-        self.set_linear(spec.lm_head, model.lm_head)
-        # only for Wav2Vec2Spec.get_vocabulary_size()
+        self.set_encoder(spec.encoder, model, model.wav2vec2.config)
        return spec

def set_config(self, config, model, tokenizer):
@@ -1021,8 +1018,36 @@ def get_vocabulary(self, model, tokenizer):
def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

-    def set_encoder(self, spec, encoder):
-        super().set_encoder(spec, encoder)
    def set_feature_extractor(self, spec, feature_extractor):
        spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
        spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias
        self.set_layer_norm(
            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
        )
        for spec_layer, module_layer in zip(
            spec.feat_layer, feature_extractor.conv_layers[1:]
        ):
            spec_layer.conv.weight = module_layer.conv.weight
            spec_layer.conv.bias = module_layer.conv.bias
            self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)

def set_feature_projection(self, spec, feature_projection):
self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
self.set_linear(spec.fp_projection, feature_projection.projection)

    def set_pos_conv_embed(self, spec, encoder, config):
        # Force the parameters to be materialized because some transformers
        # versions initialize them with garbage numbers.
        # The conv parameters may be float16, so force float32 for loading.
        encoder.pos_conv_embed.conv.weight.data = (
            encoder.pos_conv_embed.conv.weight.data.float()
        )
        encoder.pos_conv_embed.conv.bias.data = (
            encoder.pos_conv_embed.conv.bias.data.float()
        )
        for param in encoder.pos_conv_embed.parameters():
            param.data = param.data.float()
        # Dummy forward pass to force the weight-normalized conv weights to be computed.
        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias

def set_encoder(self, spec, model, config):
self.set_feature_extractor(spec, model.wav2vec2.feature_extractor)
self.set_feature_projection(spec, model.wav2vec2.feature_projection)
self.set_pos_conv_embed(spec, model.wav2vec2.encoder, config)
super().set_encoder(spec, model.wav2vec2.encoder)
self.set_linear(spec.lm_head, model.lm_head)

def set_common_layers(self, spec, module):
self.set_layer_norm(spec.layer_norm, module.layer_norm)
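With the feature extractor, feature projection, and positional conv embedding all set on the spec, a Wav2Vec2ForCTC checkpoint now converts in a single pass. A minimal usage sketch (the checkpoint name is only an example; the int8 quantization is optional):

import ctranslate2

# Example checkpoint; any Wav2Vec2ForCTC model of this family should convert the same way.
converter = ctranslate2.converters.TransformersConverter(
    "facebook/wav2vec2-large-robust-ft-swbd-300h"
)
output_dir = converter.convert("wav2vec2_ct2", quantization="int8")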
35 changes: 28 additions & 7 deletions python/ctranslate2/specs/wav2vec2_spec.py
@@ -1,5 +1,6 @@
from typing import List, Optional, Tuple

import torch.nn as nn
import numpy as np

from ctranslate2.specs import common_spec, model_spec, transformer_spec
@@ -11,12 +12,14 @@ class Wav2Vec2Config(model_spec.ModelConfig):
def __init__(self):
return


class Wav2Vec2Spec(model_spec.LanguageModelSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
        super().__init__()
-        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
-        self.lm_head = common_spec.LinearSpec()
+        self.encoder = Wav2Vec2EncoderSpec(
+            feat_layers,
+            num_layers,
+            num_heads,
+        )

@property
def name(self):
@@ -30,14 +33,32 @@ def get_default_config(self):
return Wav2Vec2Config()

    def get_vocabulary_size(self):
-        return self.lm_head.weight.shape[0]
+        return self.encoder.lm_head.weight.shape[0]


class Wav2Vec2LayerNormConvLayer(model_spec.LayerSpec):
def __init__(self):
self.conv = common_spec.Conv1DSpec()
self.layer_norm = common_spec.LayerNormSpec()


class Wav2Vec2PosEmbedConvLayer(model_spec.LayerSpec):
def __init__(self):
self.conv = common_spec.Conv1DSpec()


class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, feat_layers, num_layers, num_heads):
        self.num_heads = np.dtype("int16").type(num_heads)
-        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
        self.feat_layer0 = Wav2Vec2LayerNormConvLayer()
        self.feat_layer = [
            Wav2Vec2LayerNormConvLayer() for _ in range(feat_layers - 1)
        ]
self.fp_layer_norm = common_spec.LayerNormSpec()
self.fp_projection = common_spec.LinearSpec()
self.pos_conv_embed = Wav2Vec2PosEmbedConvLayer()
self.layer_norm = common_spec.LayerNormSpec()
self.layer = [
transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
]
self.lm_head = common_spec.LinearSpec()
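For reference, a hedged sketch of what the updated spec holds for a wav2vec2-large style checkpoint (num_feat_extract_layers=7, 24 encoder layers, 16 attention heads; the numbers are illustrative):

# Illustrative values for a wav2vec2-large style model.
spec = Wav2Vec2Spec(feat_layers=7, num_layers=24, num_heads=16)
assert len(spec.encoder.feat_layer) == 6  # feat_layer0 holds the first conv layer
assert len(spec.encoder.layer) == 24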
81 changes: 11 additions & 70 deletions python/tests/test_transformers.py
@@ -979,24 +979,15 @@ def test_transformers_wav2vec2(
)
output_dir = str(tmp_dir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)
-        # 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved

-        w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name)
-        del w2v2_model.wav2vec2.encoder.layers
-        del w2v2_model.wav2vec2.encoder.layer_norm
-        w2v2_model.save_pretrained(output_dir + "/wav2vec2_partial.bin")
        w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name)
-        torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin")
+        w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
+        processor = transformers.AutoProcessor.from_pretrained(
+            output_dir + "/wav2vec2_processor"
+        )

        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
        cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
-        w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(
-            output_dir + "/wav2vec2_partial.bin"
-        ).to(device)
-        del w2v2_model.wav2vec2.encoder.layers
-        del w2v2_model.wav2vec2.encoder.layer_norm
-        w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin")
-        ct2_w2v2_model = ctranslate2.models.Wav2Vec2(
+        model = ctranslate2.models.Wav2Vec2(
            output_dir,
            device=device,
            device_index=[0],
@@ -1015,66 +1006,16 @@
sampling_rate=16000,
).input_values

-        with torch.no_grad():
-            extract_features = w2v2_model.wav2vec2.feature_extractor(
-                input_values.to(w2v2_model.device)
-            ).transpose(1, 2)
-            hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection(
-                extract_features
-            )
-            position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed(
-                hidden_states
-            )
-            hidden_states = position_embeddings + hidden_states
-            # hidden_states = w2v2_model.encoder.dropout(hidden_states)
-            # Dropout(p=0.0, inplace=False) bypassed
-
-        if ct2_w2v2_model.device == "cuda":
-            hidden_states = hidden_states.cpu()
-        else:
-            hidden_states.numpy()
-
-        hidden_states = np.ascontiguousarray(hidden_states)
+        hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
        hidden_states = ctranslate2.StorageView.from_array(hidden_states)
-        to_cpu = (
-            ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1
-        )
-        ct2_output = ct2_w2v2_model.encode(
-            hidden_states,
-            to_cpu=to_cpu,
-        )  # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed
-        if ct2_w2v2_model.device == "cuda":
-            hidden_states = torch.as_tensor(
-                ct2_output,
-                device=ct2_w2v2_model.device,
-            )
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
+        output = model.encode(hidden_states, to_cpu=to_cpu)
+        if model.device == "cuda":
+            logits = torch.as_tensor(output, device=model.device)[0]
        else:
-            hidden_states = torch.as_tensor(
-                np.array(ct2_output),
-                dtype=torch.float32,
-                device=ct2_w2v2_model.device,
-            )
-
-        encoder_outputs = transformers.modeling_outputs.BaseModelOutput(
-            last_hidden_state=hidden_states,
-            hidden_states=None,
-            attentions=None,
-        )
-        hidden_states = encoder_outputs[0]
-        outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput(
-            last_hidden_state=hidden_states,
-            extract_features=extract_features,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-        hidden_states = outputs[0]
-        # hidden_states = w2v2_model.dropout(hidden_states)
-        # Dropout(p=0.0, inplace=False) bypassed
-
-        with torch.no_grad():
-            logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0]
+            logits = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]

        predicted_ids = torch.argmax(logits, dim=-1)
-        transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True)
+        transcription = processor.decode(predicted_ids, output_word_offsets=True)

assert transcription[0] == expected_transcription[0]
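Outside the test, the same flow works as a standalone sketch. The paths and the silent input below are placeholders; the key change is that the converted model consumes raw waveform values, since feature extraction now runs inside CTranslate2:

import ctranslate2
import numpy as np
import torch
import transformers

# Placeholders: a model converted as above, with the processor saved next to it.
processor = transformers.AutoProcessor.from_pretrained("wav2vec2_ct2/wav2vec2_processor")
model = ctranslate2.models.Wav2Vec2("wav2vec2_ct2", device="cpu", compute_type="int8")

speech = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
input_values = processor(
    speech, padding=True, return_tensors="pt", sampling_rate=16000
).input_values

features = ctranslate2.StorageView.from_array(
    np.ascontiguousarray(input_values.unsqueeze(0))
)
output = model.encode(features)
logits = torch.as_tensor(np.array(output), dtype=torch.float32)[0]
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.decode(predicted_ids))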