Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wav2Vec2Bert ASR Inference Support #1778

Merged
merged 30 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ set(SOURCES
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
src/logging.cc
src/models/language_model.cc
Expand All @@ -136,6 +137,7 @@ set(SOURCES
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
src/ops/activation.cc
src/ops/add.cc
Expand Down Expand Up @@ -182,6 +184,7 @@ set(SOURCES
src/ops/split.cc
src/ops/slide.cc
src/ops/sub.cc
src/ops/sigmoid.cc
src/ops/swish.cc
src/ops/tanh.cc
src/ops/tile.cc
Expand Down
8 changes: 8 additions & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ namespace ctranslate2 {
dim_t keys_length,
dim_t max_position);

// Declares a helper that returns a StorageView built from the query/key lengths and two
// independent position limits, one for the left side and one for the right side.
// NOTE(review): presumably the asymmetric counterpart of the single-limit relative-position
// helper whose declaration ends just above — the definition is not visible here, so confirm
// the clipping convention and the shape of the result against the implementation.
StorageView make_asymmetric_relative_positions(dim_t queries_length,
dim_t keys_length,
dim_t left_max_position,
dim_t right_max_position);

class RotaryEmbeddings;
class Alibi;

Expand Down Expand Up @@ -53,8 +58,11 @@ namespace ctranslate2 {
dim_t beam_size = 1);
const StorageView* _relative_attention_bias;
const StorageView* _relative_position_keys;
const StorageView* _relative_asymmetric_position_keys;
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
dim_t _relative_left_max_position;
dim_t _relative_right_max_position;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
};
Expand Down
126 changes: 126 additions & 0 deletions include/ctranslate2/layers/wav2vec2bert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#pragma once

#include "ctranslate2/layers/attention.h"
#include "ctranslate2/layers/flash_attention.h"
#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/transformer.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
namespace layers {

// One block of the Wav2Vec2-BERT encoder stack.
//
// The private members below are: two LayerNorm + FeedForwardNetwork pairs (_ffn1_layer_norm/_ff1
// and _ffn2_layer_norm/_ff2), a self-attention module with its own LayerNorm, a group of
// Conv1D / Sigmoid / Swish members with their LayerNorms, and a final LayerNorm.
// NOTE(review): this member set looks like a Conformer-style block — confirm against the
// operator() definition in the corresponding .cc file, which is not visible here.
class EncoderLayer : public Layer {
public:
// Constructs the layer from `model` and `scope`.
// NOTE(review): presumably `scope` is the variable-name prefix used to look up this layer's
// weights in `model`; how pre_norm, activation_type and use_flash_attention are consumed is
// decided in the out-of-line constructor — verify there.
EncoderLayer(const models::Model& model,
const std::string& scope,
const bool pre_norm = true,
const ops::ActivationType activation_type = ops::ActivationType::ReLU,
const bool use_flash_attention = false);

// Runs the layer: reads `input`, writes `output`. Const, so it does not modify the layer itself.
void operator()(const StorageView& input, StorageView& output) const;

// The layer's output type is the one reported by the final layer norm.
DataType output_type() const override {
return _final_layer_norm.output_type();
}

// The layer's output size is the one reported by the final layer norm.
dim_t output_size() const override {
return _final_layer_norm.output_size();
}

// Read-only access to the self-attention module.
// Dereferences _self_attention without a null check, so the pointer must have been set.
const AttentionLayer& get_self_attention() const {
return *_self_attention;
}

private:
// Data members are initialized in declaration order (C++ rule), so keep this order in sync
// with the constructor's initializer list when adding or moving members.
const dim_t _num_heads;
const LayerNorm _ffn1_layer_norm;
const FeedForwardNetwork _ff1;
const LayerNorm _self_attn_layer_norm;
// Held through a pointer to the AttentionLayer type.
// NOTE(review): presumably so a flash-attention implementation can be substituted — confirm.
std::unique_ptr<AttentionLayer> _self_attention;
const ops::Transpose _transpose;
const LayerNorm _layer_norm;
const Conv1D _pconv1;
const ops::Sigmoid _sigmoid;
const Conv1D _dconv;
const LayerNorm _dlayer_norm;
const ops::Swish _swish;
const Conv1D _pconv2;
const LayerNorm _ffn2_layer_norm;
const FeedForwardNetwork _ff2;
const LayerNorm _final_layer_norm;
};

// Adapter block applied by Wav2Vec2BertEncoder (which stores a vector of these as _adapt_layers).
//
// The private members below are: a LayerNorm + Conv1D pair for a "residual" path, a
// LayerNorm + Conv1D pair for an "attn" path, a Sigmoid op, a self-attention module, and a
// LayerNorm + FeedForwardNetwork pair.
// NOTE(review): how these are wired together is defined in the .cc file, not visible here.
class AdapterLayer : public Layer {
public:
// Constructs the layer from `model` and `scope`.
// NOTE(review): presumably `scope` is the variable-name prefix used to look up this layer's
// weights in `model`; the use of pre_norm, activation_type and use_flash_attention is decided
// in the out-of-line constructor — verify there.
AdapterLayer(const models::Model& model,
const std::string& scope,
const bool pre_norm = true,
const ops::ActivationType activation_type = ops::ActivationType::ReLU,
const bool use_flash_attention = false);

// Runs the layer: reads `input`, writes `output`. Const, so it does not modify the layer itself.
void operator()(const StorageView& input, StorageView& output) const;

// The layer's output type is the one reported by the feed-forward network.
DataType output_type() const override {
return _ffn.output_type();
}

// The layer's output size is the one reported by the feed-forward network.
dim_t output_size() const override {
return _ffn.output_size();
}

// Read-only access to the self-attention module.
// Dereferences _self_attention without a null check, so the pointer must have been set.
const AttentionLayer& get_self_attention() const {
return *_self_attention;
}

private:
// Data members are initialized in declaration order (C++ rule), so keep this order in sync
// with the constructor's initializer list when adding or moving members.
const dim_t _num_heads;
const LayerNorm _residual_layer_norm;
const ops::Transpose _transpose;
const Conv1D _residual_conv;
const ops::Sigmoid _sigmoid;
const LayerNorm _attn_layer_norm;
const Conv1D _attn_conv;
std::unique_ptr<AttentionLayer> _self_attention;
const LayerNorm _ffn_layer_norm;
const FeedForwardNetwork _ffn;
};

// Top-level Wav2Vec2-BERT encoder layer.
//
// Owns a LayerNorm + Dense pair (_fp_layer_norm / _fp_projection), a list of EncoderLayer
// blocks, a list of AdapterLayer blocks and a final Dense projection (_lm_head). The reported
// output type and size are those of _lm_head.
class Wav2Vec2BertEncoder : public Layer {
public:
  Wav2Vec2BertEncoder(const models::Model& model, const std::string& scope);

  // Runs the encoder: reads `features`, writes `output`.
  // Unlike the per-block layers in this header, this call operator is not const-qualified.
  void operator()(const StorageView& features, StorageView& output);

  // Type of the values produced by the final projection.
  DataType output_type() const override {
    return _lm_head.output_type();
  }

  // Size of the last dimension produced by the final projection.
  dim_t output_size() const override {
    return _lm_head.output_size();
  }

  // Fixed size expected for dimension 1 of a raw (not yet encoded) input.
  dim_t input_size() const {
    return 1024;
  }

  // Heuristic telling whether `features` already is an encoder output.
  //
  // Per the layout documented for this check:
  //   raw input      -> [batch_size, input_size, input_time]
  //   encoder output -> [batch_size, input_time // 2, output_size]
  // input_time varies, so the test is: rank 3, last dimension equal to output_size(),
  // and dimension 1 different from input_size().
  bool is_encoded(const StorageView& features) const {
    if (features.rank() != 3)
      return false;
    if (features.dim(2) != output_size())
      return false;
    return features.dim(1) != input_size();
  }

private:
  // Declaration order is initialization order; keep it aligned with the constructor.
  const LayerNorm _fp_layer_norm;
  const Dense _fp_projection;
  const std::vector<std::unique_ptr<const EncoderLayer>> _encoder_layers;
  const std::vector<std::unique_ptr<const AdapterLayer>> _adapt_layers;
  const Dense _lm_head;
};

}
}
71 changes: 71 additions & 0 deletions include/ctranslate2/models/wav2vec2bert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include "ctranslate2/layers/wav2vec2bert.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

// Option bundle declared alongside the Wav2Vec2-BERT model classes.
// NOTE(review): nothing visible in this header consumes these fields (encode() only takes the
// features and a to_cpu flag) — confirm whether they are still needed.
struct Wav2Vec2BertOptions {
  // Upper bound on the generation length.
  size_t max_length{448};

  // Sample among the K best candidates; 0 means sampling from the whole distribution.
  size_t sampling_topk{1};

  // Largest index allowed for the first predicted timestamp.
  size_t max_initial_timestamp_index{50};

  // When true, blank outputs are suppressed at the start of sampling.
  bool suppress_blank{true};

  // Token IDs to suppress.
  // The value -1 selects the default symbol set defined in the model's config.json file.
  std::vector<int> suppress_tokens{-1};
};


// Model subclass for Wav2Vec2-BERT.
// Holds a shared, read-only Vocabulary and overrides the Model hooks declared below; all
// member functions except use_global_int16_scale() are defined out of line.
class Wav2Vec2BertModel : public Model {
public:
// Returns a reference to the model's vocabulary.
// NOTE(review): presumably dereferences _vocabulary, which is expected to be set by initialize() — confirm.
const Vocabulary& get_vocabulary() const;
size_t current_spec_revision() const override;
// Per-variable predicates keyed on the variable name.
bool is_quantizable(const std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
std::unique_ptr<Model> clone() const override;

// This model always reports false for the global int16 scale setting.
bool use_global_int16_scale() const override {
return false;
}

protected:
void initialize(ModelReader& model_reader) override;
private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

// ModelReplica subclass that pairs a shared Wav2Vec2BertModel with its own Wav2Vec2BertEncoder instance.
// This is the replica type used by the Wav2Vec2Bert pool (ReplicaPool<Wav2Vec2BertReplica>).
class Wav2Vec2BertReplica : public ModelReplica {
public:
// Factory taking a generic Model reference.
// NOTE(review): presumably expects `model` to actually be a Wav2Vec2BertModel — confirm how a
// mismatching model type is handled in the definition.
static std::unique_ptr<Wav2Vec2BertReplica> create_from_model(const Model& model);

Wav2Vec2BertReplica(const std::shared_ptr<const Wav2Vec2BertModel>& model);

// Encodes `features` (taken by value) and returns the result as a StorageView.
// NOTE(review): `to_cpu` presumably requests that the returned storage live on the CPU — confirm.
StorageView encode(StorageView features, const bool to_cpu);

private:
const std::shared_ptr<const Wav2Vec2BertModel> _model;
const std::unique_ptr<layers::Wav2Vec2BertEncoder> _encoder;

// NOTE(review): presumably skips the encoder when the input already looks encoded
// (see layers::Wav2Vec2BertEncoder::is_encoded) — verify against the definition.
StorageView maybe_encode(StorageView features);
};

// Pool of Wav2Vec2BertReplica instances; constructors are inherited unchanged from ReplicaPool.
class Wav2Vec2Bert : public ReplicaPool<Wav2Vec2BertReplica> {
public:
using ReplicaPool::ReplicaPool;

// Asynchronous encode: returns a future that yields the resulting StorageView.
// NOTE(review): presumably dispatches to Wav2Vec2BertReplica::encode on one of the pooled
// replicas; `to_cpu` is forwarded as-is — confirm in the definition.
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

};

}
}
1 change: 1 addition & 0 deletions include/ctranslate2/ops/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace ctranslate2 {
GELU,
GELUSigmoid,
Tanh,
Sigmoid,
};

const UnaryOp& get_activation_op(ActivationType type);
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "split.h"
#include "squeeze.h"
#include "sub.h"
#include "sigmoid.h"
#include "swish.h"
#include "tile.h"
#include "topk.h"
Expand Down
21 changes: 21 additions & 0 deletions include/ctranslate2/ops/sigmoid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class Sigmoid : public UnaryOp {
public:
void operator()(const StorageView& x, StorageView& y) const override;

private:
template <Device D, typename T>
void compute(const StorageView& x, StorageView& y) const {
y.resize_as(x);
primitives<D>::sigmoid(x.data<T>(), y.data<T>(), x.size());
}
};

}
}
2 changes: 2 additions & 0 deletions include/ctranslate2/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ namespace ctranslate2 {
template <typename T>
static void gelu_sigmoid(const T* x, T* y, dim_t size);
// Sigmoid primitive: reads `size` elements from `x` and writes `size` elements to `y`.
// Declared alongside the gelu_sigmoid and swish primitives, with the same signature shape.
// NOTE(review): per-device and per-type definitions live outside this header — confirm which
// element types are instantiated.
template <typename T>
static void sigmoid(const T* x, T* y, dim_t size);
template <typename T>
static void swish(const T* x, T* y, dim_t size);

static void compute_u8_compensation(const int8_t* b,
Expand Down
1 change: 1 addition & 0 deletions python/cpp/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,6 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace ctranslate2 {
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);

}
Expand Down
Loading
Loading