diff --git a/CMakeLists.txt b/CMakeLists.txt
index 52610ac89..62fc33640 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -127,6 +127,7 @@ set(SOURCES
   src/layers/decoder.cc
   src/layers/transformer.cc
   src/layers/wav2vec2.cc
+  src/layers/wav2vec2bert.cc
   src/layers/whisper.cc
   src/logging.cc
   src/models/language_model.cc
@@ -136,6 +137,7 @@ set(SOURCES
   src/models/sequence_to_sequence.cc
   src/models/transformer.cc
   src/models/wav2vec2.cc
+  src/models/wav2vec2bert.cc
   src/models/whisper.cc
   src/ops/activation.cc
   src/ops/add.cc
@@ -182,6 +184,7 @@ set(SOURCES
   src/ops/split.cc
   src/ops/slide.cc
   src/ops/sub.cc
+  src/ops/sigmoid.cc
   src/ops/swish.cc
   src/ops/tanh.cc
   src/ops/tile.cc
diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index 87b21f725..5035f75a2 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -7,9 +7,11 @@ namespace ctranslate2 {
   namespace layers {
 
     StorageView make_relative_positions(dim_t queries_length,
-                                        dim_t keys_length,
-                                        dim_t max_position);
+                                        dim_t keys_length,
+                                        dim_t max_position,
+                                        dim_t left_max_position,
+                                        dim_t right_max_position);
 
     class RotaryEmbeddings;
     class Alibi;
@@ -53,8 +55,11 @@ namespace ctranslate2 {
                                  dim_t beam_size = 1);
       const StorageView* _relative_attention_bias;
       const StorageView* _relative_position_keys;
+      const StorageView* _relative_asymmetric_position_keys;
       const StorageView* _relative_position_values;
       dim_t _maximum_relative_position;
+      dim_t _relative_left_max_position;
+      dim_t _relative_right_max_position;
       const bool _merge_time_and_head_dims;
       const dim_t _cache_time_dim;
     };
diff --git a/include/ctranslate2/ops/activation.h b/include/ctranslate2/ops/activation.h
index f500fcf9e..a9bff98cd 100644
--- a/include/ctranslate2/ops/activation.h
+++ b/include/ctranslate2/ops/activation.h
@@ -13,6 +13,7 @@ namespace ctranslate2 {
       GELU,
       GELUSigmoid,
       Tanh,
+      Sigmoid,
     };
 
     const UnaryOp& get_activation_op(ActivationType type);
diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h
index ed9db4265..2a735e394 100644
--- a/include/ctranslate2/ops/ops.h
+++ b/include/ctranslate2/ops/ops.h
@@ -22,6 +22,7 @@
 #include "split.h"
 #include "squeeze.h"
 #include "sub.h"
+#include "sigmoid.h"
 #include "swish.h"
 #include "tile.h"
 #include "topk.h"
diff --git a/include/ctranslate2/primitives.h b/include/ctranslate2/primitives.h
index bed80c8ff..571121554 100644
--- a/include/ctranslate2/primitives.h
+++ b/include/ctranslate2/primitives.h
@@ -181,6 +181,8 @@ namespace ctranslate2 {
     template <typename T>
     static void gelu_sigmoid(const T* x, T* y, dim_t size);
     template <typename T>
+    static void sigmoid(const T* x, T* y, dim_t size);
+    template <typename T>
     static void swish(const T* x, T* y, dim_t size);
 
     static void compute_u8_compensation(const int8_t* b,
diff --git a/python/cpp/module.cc b/python/cpp/module.cc
index 4489d5314..550aea5b2 100644
--- a/python/cpp/module.cc
+++ b/python/cpp/module.cc
@@ -87,5 +87,6 @@ PYBIND11_MODULE(_ext, m)
   ctranslate2::python::register_encoder(m);
   ctranslate2::python::register_whisper(m);
   ctranslate2::python::register_wav2vec2(m);
+  ctranslate2::python::register_wav2vec2bert(m);
   ctranslate2::python::register_mpi(m);
 }
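For reference, the Sigmoid activation wired through the enums and primitives above is the standard logistic function; the CPU and CUDA kernels later in the patch compute it as 1 / (1 + exp(-x)). A minimal NumPy sketch of the same math (illustrative only, not part of the patch):

import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    # Same formula as the new kernels: 1 / (1 + exp(-x)).
    return 1.0 / (1.0 + np.exp(-x))

print(sigmoid(np.array([0.2, -1.3])))  # ~[0.54983395, 0.21416503], the values checked in ops_test.cc
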
diff --git a/python/cpp/module.h b/python/cpp/module.h
index 9c9a9a2ff..71d4b3b29 100644
--- a/python/cpp/module.h
+++ b/python/cpp/module.h
@@ -18,6 +18,7 @@ namespace ctranslate2 {
   void register_translator(py::module& m);
   void register_whisper(py::module& m);
   void register_wav2vec2(py::module& m);
+  void register_wav2vec2bert(py::module& m);
   void register_mpi(py::module& m);
 }
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index cd8e8aef4..17ac66e5f 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -23,6 +23,7 @@
     model_spec,
     transformer_spec,
     wav2vec2_spec,
+    wav2vec2bert_spec,
     whisper_spec,
 )
 
@@ -1059,6 +1060,105 @@ def set_common_layers(self, spec, module):
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
 
 
+@register_loader("Wav2Vec2BertConfig")
+class Wav2Vec2BertLoader(BartLoader):
+    @property
+    def architecture_name(self):
+        return "Wav2Vec2BertForCTC"
+
+    def get_model_spec(self, model):
+        spec = wav2vec2bert_spec.Wav2Vec2BertSpec(
+            model.wav2vec2_bert.config.num_adapter_layers,
+            model.wav2vec2_bert.config.num_hidden_layers,
+        )
+        self.set_encoder(spec.encoder, model)
+        return spec
+
+    def set_config(self, config, model, tokenizer):
+        return
+
+    def get_vocabulary(self, model, tokenizer):
+        return tokenizer.get_vocab()
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_feature_projection(self, spec, feature_projection):
+        self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
+        self.set_linear(spec.fp_projection, feature_projection.projection)
+
+    def set_attention(
+        self, spec, attention, left_max_position=None, right_max_position=None
+    ):
+        split_layers = [common_spec.LinearSpec() for _ in range(3)]
+        self.set_linear(split_layers[0], attention.linear_q)
+        self.set_linear(split_layers[1], attention.linear_k)
+        self.set_linear(split_layers[2], attention.linear_v)
+        utils.fuse_linear(spec.linear[0], split_layers)
+        self.set_linear(spec.linear[-1], attention.linear_out)
+        if left_max_position or right_max_position:
+            spec.relative_asymmetric_position_keys = attention.distance_embedding.weight
+            spec.relative_left_max_position = np.dtype("int32").type(left_max_position)
+            spec.relative_right_max_position = np.dtype("int32").type(
+                right_max_position
+            )
+
+    def set_wav2vec2bert_encoder(
+        self, spec_layers, layers, left_max_position, right_max_position
+    ):
+        for slayer, layer in zip(spec_layers, layers):
+            self.set_layer_norm(slayer.enc_ffn1_layer_norm, layer.ffn1_layer_norm)
+            self.set_linear(slayer.enc_ffn1.linear_0, layer.ffn1.intermediate_dense)
+            self.set_linear(slayer.enc_ffn1.linear_1, layer.ffn1.output_dense)
+            self.set_attention(
+                slayer.enc_attn, layer.self_attn, left_max_position, right_max_position
+            )
+            self.set_layer_norm(slayer.enc_attn_layer_norm, layer.self_attn_layer_norm)
+            self.set_layer_norm(
+                slayer.enc_conv_layer_norm, layer.conv_module.layer_norm
+            )
+            self.set_conv1d(
+                slayer.enc_conv_pointwise_conv1, layer.conv_module.pointwise_conv1
+            )
+            self.set_conv1d(
+                slayer.enc_conv_depthwise_conv, layer.conv_module.depthwise_conv
+            )
+            self.set_layer_norm(
+                slayer.enc_conv_depthwise_layer_norm,
+                layer.conv_module.depthwise_layer_norm,
+            )
+            self.set_conv1d(
+                slayer.enc_conv_pointwise_conv2, layer.conv_module.pointwise_conv2
+            )
+            self.set_layer_norm(slayer.enc_ffn2_layer_norm, layer.ffn2_layer_norm)
+            self.set_linear(slayer.enc_ffn2.linear_0, layer.ffn2.intermediate_dense)
+            self.set_linear(slayer.enc_ffn2.linear_1, layer.ffn2.output_dense)
+            self.set_layer_norm(slayer.enc_final_layer_norm, layer.final_layer_norm)
+    def set_wav2vec2bert_adapter(self, spec_layers, layers):
+        for slayer, layer in zip(spec_layers, layers):
+            self.set_layer_norm(
+                slayer.adpt_residual_layer_norm, layer.residual_layer_norm
+            )
+            self.set_conv1d(slayer.adpt_residual_conv, layer.residual_conv)
+            self.set_layer_norm(slayer.adpt_attn_layer_norm, layer.self_attn_layer_norm)
+            self.set_conv1d(slayer.adpt_attn_conv, layer.self_attn_conv)
+            self.set_attention(slayer.adpt_attn_layer, layer.self_attn)
+            self.set_layer_norm(slayer.adpt_ffn_layer_norm, layer.ffn_layer_norm)
+            self.set_linear(slayer.adpt_ffn.linear_0, layer.ffn.intermediate_dense)
+            self.set_linear(slayer.adpt_ffn.linear_1, layer.ffn.output_dense)
+
+    def set_encoder(self, spec, model):
+        self.set_feature_projection(spec, model.wav2vec2_bert.feature_projection)
+        self.set_wav2vec2bert_encoder(
+            spec.encoder_layers,
+            model.wav2vec2_bert.encoder.layers,
+            model.wav2vec2_bert.config.left_max_position_embeddings,
+            model.wav2vec2_bert.config.right_max_position_embeddings,
+        )
+        self.set_wav2vec2bert_adapter(
+            spec.adapter_layers, model.wav2vec2_bert.adapter.layers
+        )
+        self.set_linear(spec.lm_head, model.lm_head)
+
+    def set_conv1d(self, spec, module):
+        spec.weight = module.weight
+        if module.bias is not None:
+            spec.bias = module.bias
+
+    def set_layer_norm(self, spec, module):
+        spec.gamma = module.weight
+        if module.bias is not None:
+            spec.beta = module.bias
+
+
 @register_loader("T5Config")
 class T5Loader(ModelLoader):
     @property
@@ -1421,110 +1521,6 @@ def set_decoder(self, spec, module):
         gc.collect()
 
 
-@register_loader("Gemma2Config")
-class Gemma2Loader(ModelLoader):
-    @property
-    def architecture_name(self):
-        return "Gemma2ForCausalLM"
-
-    def get_model_spec(self, model):
-        num_layers = model.config.num_hidden_layers
-
-        num_heads = model.config.num_attention_heads
-        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
-        if num_heads_kv == num_heads:
-            num_heads_kv = None
-
-        activation_config = getattr(
-            model.config, "hidden_activation", "gelu_pytorch_tanh"
-        )
-
-        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
-            num_layers,
-            num_heads,
-            activation=(
-                common_spec.Activation.GELU
-                if activation_config == "gelu"
-                else common_spec.Activation.GELUTanh
-            ),
-            pre_norm=True,
-            ffn_glu=True,
-            rms_norm=True,
-            rotary_dim=0,
-            rotary_interleave=False,
-            rotary_base=getattr(model.config, "rope_theta", 10000),
-            num_heads_kv=num_heads_kv,
-            head_dim=model.config.head_dim,
-            pre_post_layer_norm=True,
-        )
-
-        self.set_decoder(spec.decoder, model.model)
-        self.set_linear(spec.decoder.projection, model.lm_head)
-        spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
-        return spec
-
-    def get_vocabulary(self, model, tokenizer):
-        tokens = super().get_vocabulary(model, tokenizer)
-
-        extra_ids = model.config.vocab_size - len(tokens)
-        for i in range(extra_ids):
-            tokens.append("<extra_id_%d>" % i)
-        if model.config.vocab_size < len(tokens):
-            tokens = tokens[: model.config.vocab_size]
-
-        return tokens
-
-    def set_vocabulary(self, spec, tokens):
-        spec.register_vocabulary(tokens)
-
-    def set_config(self, config, model, tokenizer):
-        config.bos_token = tokenizer.bos_token
-        config.eos_token = tokenizer.eos_token
-        config.unk_token = tokenizer.unk_token
-        config.layer_norm_epsilon = model.config.rms_norm_eps
-
-    def set_layer_norm(self, spec, layer_norm):
-        spec.gamma = layer_norm.weight
-        spec.layer_norm_use_residual = True
-
-    def set_decoder(self, spec, module):
-        spec.scale_embeddings = True
-        spec.start_from_zero_embedding = False
-        self.set_embeddings(spec.embeddings, module.embed_tokens)
-        self.set_layer_norm(spec.layer_norm, module.norm)
-
-        for layer_spec, layer in zip(spec.layer, module.layers):
-            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
-
-            self.set_layer_norm(
-                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
-            )
-
-            self.set_layer_norm(
-                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
-            )
-
-            self.set_layer_norm(
-                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
-            )
-
-            wq = layer.self_attn.q_proj.weight
-            wk = layer.self_attn.k_proj.weight
-            wv = layer.self_attn.v_proj.weight
-            wo = layer.self_attn.o_proj.weight
-
-            layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
-            layer_spec.self_attention.linear[1].weight = wo
-
-            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
-            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
-            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
-
-            delattr(layer, "self_attn")
-            delattr(layer, "mlp")
-            gc.collect()
-
-
 @register_loader("LlamaConfig")
 class LlamaLoader(ModelLoader):
     @property
diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py
index aba612a5c..35a3dca37 100644
--- a/python/ctranslate2/models/__init__.py
+++ b/python/ctranslate2/models/__init__.py
@@ -5,6 +5,7 @@
 try:
     from ctranslate2._ext import (
         Wav2Vec2,
+        Wav2Vec2Bert,
         Whisper,
         WhisperGenerationResult,
         WhisperGenerationResultAsync,
diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py
index 22552f5c9..b4e53fad2 100644
--- a/python/ctranslate2/specs/__init__.py
+++ b/python/ctranslate2/specs/__init__.py
@@ -14,4 +14,5 @@
     TransformerSpec,
 )
 from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
+from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec
 from ctranslate2.specs.whisper_spec import WhisperSpec
diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index 2180d779b..f49d41121 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -19,6 +19,7 @@ def __init__(
         self,
         self_attention=False,
         relative_position=False,
+        relative_asymmetric_position=False,
         relative_attention_bias=False,
         rms_norm=False,
         rotary_dim=None,
@@ -47,6 +48,11 @@ def __init__(
             self.relative_attention_bias = None
             self.relative_attention_max_distance = None
 
+        if relative_asymmetric_position:
+            self.relative_asymmetric_position_keys = None
+            self.relative_left_max_position = None
+            self.relative_right_max_position = None
+
         if original_max_position_embeddings != 0:
             self.original_max_position_embeddings = np.dtype("int32").type(
                 original_max_position_embeddings
diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py
index b1162839c..598a452d6 100644
--- a/python/ctranslate2/specs/common_spec.py
+++ b/python/ctranslate2/specs/common_spec.py
@@ -13,6 +13,7 @@ class Activation(enum.IntEnum):
     GELU = 3
     GELUSigmoid = 4
     Tanh = 5
+    Sigmoid = 6
 
 
 # This enum should match the C++ equivalent in include/ctranslate2/layers/common.h.
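With the loader and specs above in place, a Wav2Vec2BERT checkpoint converts through the usual TransformersConverter entry point, as the test below exercises. A minimal usage sketch (the output directory name is arbitrary):

import ctranslate2

converter = ctranslate2.converters.TransformersConverter(
    "hf-audio/wav2vec2-bert-CV16-en"
)
output_dir = converter.convert("wav2vec2bert_ct2")  # writes model.bin and the vocabulary
print(output_dir)
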
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 3c35445fa..7d536eae7 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -1023,3 +1023,83 @@ def test_transformers_wav2vec2(
     transcription = transcription[0].replace(processor.tokenizer.unk_token, "")
 
     assert transcription == expected_transcription[0]
+
+
+class TestWav2Vec2Bert:
+    @classmethod
+    def teardown_class(cls):
+        clear_transformers_cache_in_ci()
+
+    @test_utils.only_on_linux
+    @test_utils.on_available_devices
+    @pytest.mark.parametrize(
+        "model_name,expected_transcription",
+        [
+            (
+                "hf-audio/wav2vec2-bert-CV16-en",
+                [
+                    "mr quilter is the apostle of the middle classes and"
+                    " we are glad to welcome his gospel"
+                ],
+            ),
+        ],
+    )
+    def test_transformers_wav2vec2bert(
+        self,
+        tmp_dir,
+        device,
+        model_name,
+        expected_transcription,
+    ):
+        import torch
+        import transformers
+
+        converter = ctranslate2.converters.TransformersConverter(
+            model_name, load_as_float16="int8"
+        )
+        output_dir = str(tmp_dir.join("ctranslate2_model"))
+        output_dir = converter.convert(output_dir)
+
+        w2v2_processor = transformers.Wav2Vec2BertProcessor.from_pretrained(model_name)
+        w2v2_processor.save_pretrained(output_dir + "/wav2vec2_processor")
+        processor = transformers.AutoProcessor.from_pretrained(
+            output_dir + "/wav2vec2_processor"
+        )
+
+        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+        cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
+        model = ctranslate2.models.Wav2Vec2Bert(
+            output_dir,
+            device=device,
+            device_index=[0],
+            compute_type="int8",
+            intra_threads=cpu_threads,
+            inter_threads=1,
+        )
+
+        speech_array = np.load(
+            os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
+        )
+        input_values = processor(
+            [speech_array],
+            padding=True,
+            return_tensors="pt",
+            sampling_rate=16000,
+        ).input_features
+
+        hidden_states = np.ascontiguousarray(input_values)
+        hidden_states = ctranslate2.StorageView.from_array(hidden_states)
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
+        output = model.encode(hidden_states, to_cpu=to_cpu)
+        if model.device == "cuda":
+            logits = torch.as_tensor(output, device=model.device)[0]
+        else:
+            logits = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]
+
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.decode(predicted_ids, output_word_offsets=True)
+        transcription = transcription[0].replace(processor.tokenizer.unk_token, "")
+
+        assert transcription == expected_transcription[0]
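The test mirrors the upstream transformers CTC pipeline: extract features, run the encoder, argmax the logits over the vocabulary, then CTC-decode with the processor. For comparison, a sketch of the pure-transformers path (the waveform path is illustrative; Wav2Vec2BertForCTC is available in recent transformers releases):

import numpy as np
import torch
import transformers

model_name = "hf-audio/wav2vec2-bert-CV16-en"
processor = transformers.AutoProcessor.from_pretrained(model_name)
model = transformers.Wav2Vec2BertForCTC.from_pretrained(model_name).eval()

speech_array = np.load("mr_quilter.npy")  # 16 kHz mono waveform; path illustrative
inputs = processor([speech_array], sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_features).logits  # (batch, frames, vocab)

print(processor.decode(logits.argmax(dim=-1)[0]))
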
diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc
index 2371704ec..c1f48553d 100644
--- a/src/cpu/kernels.cc
+++ b/src/cpu/kernels.cc
@@ -184,6 +184,13 @@ namespace ctranslate2 {
       }
     };
 
+    struct sigmoid_func {
+      vec_type operator()(vec_type v) const {
+        using VecType = Vec<float, TARGET_ISA>;
+        return VecType::div(VecType::load(1.f), VecType::add(VecType::load(1.f), VecType::exp(VecType::neg(v))));
+      }
+    };
+
     struct swish_func {
       vec_type operator()(vec_type v) const {
         using VecType = Vec<float, TARGET_ISA>;
@@ -244,6 +251,11 @@ namespace ctranslate2 {
       vectorized_unary_transform(x, y, size, gelu_sigmoid_func());
     }
 
+    template<>
+    void sigmoid<TARGET_ISA>(const float* x, float* y, dim_t size) {
+      vectorized_unary_transform(x, y, size, sigmoid_func());
+    }
+
     template<>
     void swish<TARGET_ISA>(const float* x, float* y, dim_t size) {
       vectorized_unary_transform(x, y, size, swish_func());
@@ -696,6 +708,9 @@ namespace ctranslate2 {
       case ops::ActivationType::GELUSigmoid:
         dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, gelu_sigmoid_func());
         break;
+      case ops::ActivationType::Sigmoid:
+        dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, sigmoid_func());
+        break;
       case ops::ActivationType::Swish:
         dequantize_gemm_output_row(c, a_scale, b_scale, bias, m, y, swish_func());
         break;
diff --git a/src/cpu/kernels.h b/src/cpu/kernels.h
index 71d52cc67..16296fc36 100644
--- a/src/cpu/kernels.h
+++ b/src/cpu/kernels.h
@@ -27,6 +27,8 @@ namespace ctranslate2 {
     template <CpuIsa ISA>
     void gelu_sigmoid(const float* x, float* y, dim_t size);
     template <CpuIsa ISA>
+    void sigmoid(const float* x, float* y, dim_t size);
+    template <CpuIsa ISA>
     void swish(const float* x, float* y, dim_t size);
 
     template <CpuIsa ISA>
diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc
index 0c6377bbb..5e0fd2999 100644
--- a/src/cpu/primitives.cc
+++ b/src/cpu/primitives.cc
@@ -313,6 +313,15 @@ namespace ctranslate2 {
                       });
   }
 
+  template<>
+  template<>
+  void primitives<Device::CPU>::sigmoid(const float* x, float* y, dim_t size) {
+    cpu::parallel_for(0, size, cpu::GRAIN_SIZE / 10,
+                      [x, y](dim_t begin, dim_t end) {
+                        CPU_ISA_DISPATCH((cpu::sigmoid<ISA>(x + begin, y + begin, end - begin)));
+                      });
+  }
+
   template<>
   template<>
   void primitives<Device::CPU>::swish(const float* x, float* y, dim_t size) {
diff --git a/src/cuda/helpers.h b/src/cuda/helpers.h
index a34d5d892..391fae73f 100644
--- a/src/cuda/helpers.h
+++ b/src/cuda/helpers.h
@@ -255,6 +255,14 @@ namespace ctranslate2 {
       }
     };
 
+    template <typename T>
+    struct sigmoid_func {
+      // Implicitly promote half to float in this function.
+      __device__ float operator()(float x) const {
+        return 1.f / (1.f + expf(-x));
+      }
+    };
+
     template <typename T>
     struct swish_func {
       // Implicitly promote half to float in this function.
diff --git a/src/cuda/primitives.cu b/src/cuda/primitives.cu
index 149e10dbb..9915bb12c 100644
--- a/src/cuda/primitives.cu
+++ b/src/cuda/primitives.cu
@@ -218,6 +218,12 @@ namespace ctranslate2 {
     cuda::unary_transform(x, y, size, cuda::gelu_sigmoid_func<cuda::device_type<T>>());
   }
 
+  template<>
+  template <typename T>
+  void primitives<Device::CUDA>::sigmoid(const T* x, T* y, dim_t size) {
+    cuda::unary_transform(x, y, size, cuda::sigmoid_func<cuda::device_type<T>>());
+  }
+
   template<>
   template <typename T>
   void primitives<Device::CUDA>::swish(const T* x, T* y, dim_t size) {
@@ -789,6 +795,7 @@ namespace ctranslate2 {
   template void primitives<Device::CUDA>::gelu(const T*, T*, dim_t);          \
   template void primitives<Device::CUDA>::gelu_tanh(const T*, T*, dim_t);     \
   template void primitives<Device::CUDA>::gelu_sigmoid(const T*, T*, dim_t);  \
+  template void primitives<Device::CUDA>::sigmoid(const T*, T*, dim_t);       \
   template void primitives<Device::CUDA>::swish(const T*, T*, dim_t);         \
   template float primitives<Device::CUDA>::logsumexp(const T*, dim_t);        \
   template void primitives<Device::CUDA>::sin(const T*, T*, dim_t);           \
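The ActivationType::Sigmoid cases added to the CPU dequantize path above (and to the CUDA path later in the patch) fuse the activation into the int8 GEMM epilogue: the int32 accumulators are rescaled by the per-row and per-column quantization scales, the bias is added, and the activation is applied in one pass. A NumPy sketch of that epilogue (illustrative names; the real kernels work row by row without temporaries):

import numpy as np

def dequantize_gemm_output(c, a_scales, b_scales, bias=None):
    # c: int32 accumulators of an int8 GEMM; a_scales per row, b_scales per column.
    y = c.astype(np.float32) / (a_scales[:, None] * b_scales[None, :])
    if bias is not None:
        y += bias
    return 1.0 / (1.0 + np.exp(-y))  # fused Sigmoid epilogue
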
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index a206bcd05..dffcafa56 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -15,16 +15,22 @@ namespace ctranslate2 {
 
     StorageView make_relative_positions(dim_t queries_length,
                                         dim_t keys_length,
-                                        dim_t max_position) {
+                                        dim_t max_position,
+                                        dim_t left_max_position,
+                                        dim_t right_max_position) {
       StorageView positions({queries_length, keys_length}, DataType::INT32);
       auto* positions_data = positions.data<int32_t>();
       const dim_t offset = keys_length - queries_length;
+      bool asymmetric = (left_max_position != 0 || right_max_position != 0);
 
       for (dim_t i = 0; i < queries_length; ++i) {
         auto* row = positions_data + i * keys_length;
         for (dim_t j = 0; j < keys_length; ++j) {
-          row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position;
+          if (asymmetric)
+            row[j] = std::max(std::min(j - i, right_max_position), -left_max_position) + left_max_position;
+          else
+            row[j] = std::min(std::max(j - (i + offset), -max_position), max_position) + max_position;
         }
       }
 
       return positions;
@@ -163,8 +169,11 @@ namespace ctranslate2 {
                                       const StorageView& values,
                                       const StorageView* values_lengths,
                                       const StorageView* relative_position_keys,
+                                      const StorageView* relative_asymmetric_position_keys,
                                       const StorageView* relative_position_values,
                                       const StorageView* relative_attention_bias,
+                                      dim_t relative_left_max_position,
+                                      dim_t relative_right_max_position,
                                       dim_t maximum_relative_position,
                                       StorageView& output,
                                       StorageView* attention = nullptr,
@@ -178,13 +187,23 @@ namespace ctranslate2 {
       PROFILE("dot_product_attention");
 
       std::unique_ptr<StorageView> relative_positions;
-      if (relative_position_keys || relative_position_values) {
+      if (relative_position_keys || relative_position_values || relative_asymmetric_position_keys) {
         const dim_t query_length = queries.dim(2);
         const dim_t key_length = keys.dim(2);
-        relative_positions = std::make_unique<StorageView>(
-          make_relative_positions(query_length,
-                                  key_length,
-                                  maximum_relative_position).to(queries.device()));
+        if (relative_asymmetric_position_keys)
+          relative_positions = std::make_unique<StorageView>(
+            make_relative_positions(query_length,
+                                    key_length,
+                                    /*maximum_relative_position=*/0,
+                                    relative_left_max_position,
+                                    relative_right_max_position).to(queries.device()));
+        else
+          relative_positions = std::make_unique<StorageView>(
+            make_relative_positions(query_length,
+                                    key_length,
+                                    maximum_relative_position,
+                                    /*relative_left_max_position=*/0,
+                                    /*relative_right_max_position=*/0).to(queries.device()));
       }
 
       const ops::MatMul keys_matmul(/*trans_a=*/false, /*trans_b=*/true, queries_scale);
@@ -196,6 +215,12 @@ namespace ctranslate2 {
                                      keys_matmul,
                                      output);
 
+      if (relative_asymmetric_position_keys)
+        add_relative_representations(queries,
+                                     *relative_positions,
+                                     *relative_asymmetric_position_keys,
+                                     keys_matmul,
+                                     output);
 
       if (relative_attention_bias) {
         StorageView local_position_bias(output.dtype(), output.device());
@@ -269,6 +294,7 @@ namespace ctranslate2 {
       : AttentionLayer(model, scope, num_heads, self_attention, pre_norm, is_decoder, alibi, false)
       , _relative_attention_bias(model.get_variable_if_exists(scope + "/relative_attention_bias"))
       , _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys"))
+      , _relative_asymmetric_position_keys(model.get_variable_if_exists(scope + "/relative_asymmetric_position_keys"))
       , _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values"))
       , _merge_time_and_head_dims(_multi_query
                                   && !_relative_attention_bias
@@ -278,6 +304,12 @@ namespace ctranslate2 {
     {
       if (_relative_position_keys)
        _maximum_relative_position = (_relative_position_keys->dim(0) - 1) / 2;
+      else if (_relative_asymmetric_position_keys) {
+        _relative_left_max_position = model.get_attribute<int32_t>(
+          scope + "/relative_left_max_position");
+        _relative_right_max_position = model.get_attribute<int32_t>(
+          scope + "/relative_right_max_position");
+      }
       else if (_relative_attention_bias)
         _maximum_relative_position = model.get_attribute<int32_t>(
           scope + "/relative_attention_max_distance");
@@ -363,7 +395,7 @@ namespace ctranslate2 {
       if (queries_padder)
         queries_padder->add_padding(fused_proj);
 
-      const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+      const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
       split_op(fused_proj, queries_proj, keys_proj, values_proj);
 
       if (_merge_time_and_head_dims) {
@@ -432,8 +464,11 @@ namespace ctranslate2 {
                               values_proj,
                               values_lengths,
                               _relative_position_keys,
+                              _relative_asymmetric_position_keys,
                               _relative_position_values,
                               _relative_attention_bias,
+                              _relative_left_max_position,
+                              _relative_right_max_position,
                               _maximum_relative_position,
                               context,
                               attention,
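The asymmetric branch above clips the raw key/query distance j - i to [-left_max_position, right_max_position] and shifts it by left_max_position, yielding indices into a distance embedding table of size left_max_position + right_max_position + 1 (the distance_embedding weight exported by the converter). A NumPy sketch of the index matrix it builds:

import numpy as np

def make_asymmetric_relative_positions(query_len, key_len, left_max, right_max):
    # row[i][j] = clip(j - i, -left_max, right_max) + left_max
    i = np.arange(query_len)[:, None]
    j = np.arange(key_len)[None, :]
    return np.clip(j - i, -left_max, right_max) + left_max

print(make_asymmetric_relative_positions(3, 3, 2, 1))
# [[2 3 3]
#  [1 2 3]
#  [0 1 2]]
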
diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc
index 488e0b8b2..059051f5d 100644
--- a/src/models/model_factory.cc
+++ b/src/models/model_factory.cc
@@ -4,6 +4,7 @@
 
 #include "ctranslate2/models/whisper.h"
 #include "ctranslate2/models/wav2vec2.h"
+#include "ctranslate2/models/wav2vec2bert.h"
 #include "ctranslate2/models/transformer.h"
 
 namespace ctranslate2 {
@@ -23,6 +24,8 @@ namespace ctranslate2 {
 
       register_model<WhisperModel>("WhisperSpec");
       register_model<Wav2Vec2Model>("Wav2Vec2Spec");
+
+      register_model<Wav2Vec2BertModel>("Wav2Vec2BertSpec");
     }
 
     std::shared_ptr<Model> create_model(const std::string& name) {
diff --git a/src/ops/activation.cc b/src/ops/activation.cc
index df2bfe6c1..5de89ffb3 100644
--- a/src/ops/activation.cc
+++ b/src/ops/activation.cc
@@ -2,6 +2,7 @@
 
 #include "ctranslate2/ops/gelu.h"
 #include "ctranslate2/ops/relu.h"
+#include "ctranslate2/ops/sigmoid.h"
 #include "ctranslate2/ops/swish.h"
 #include "ctranslate2/ops/tanh.h"
 
@@ -26,6 +27,10 @@ namespace ctranslate2 {
         static const GELU gelu(GELU::Approximation::Sigmoid);
         return gelu;
       }
+      case ActivationType::Sigmoid: {
+        static const Sigmoid sigmoid;
+        return sigmoid;
+      }
       case ActivationType::Swish: {
         static const Swish swish;
         return swish;
diff --git a/src/ops/bias_add_gpu.cu b/src/ops/bias_add_gpu.cu
index 8f53bcf64..951a7671e 100644
--- a/src/ops/bias_add_gpu.cu
+++ b/src/ops/bias_add_gpu.cu
@@ -61,6 +61,11 @@ namespace ctranslate2 {
           x, b, y, depth, cuda::plus<DeviceT>(), cuda::gelu_sigmoid_func<DeviceT>());
         break;
 
+      case ActivationType::Sigmoid:
+        bias_add_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
+          x, b, y, depth, cuda::plus<DeviceT>(), cuda::sigmoid_func<DeviceT>());
+        break;
+
       case ActivationType::Swish:
         bias_add_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
           x, b, y, depth, cuda::plus<DeviceT>(), cuda::swish_func<DeviceT>());
diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu
index 241b3acdb..d14ae7efb 100644
--- a/src/ops/dequantize_gpu.cu
+++ b/src/ops/dequantize_gpu.cu
@@ -98,6 +98,12 @@ namespace ctranslate2 {
         break;
       }
 
+      case ActivationType::Sigmoid: {
+        dequantize_gemm_output_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
+          c, a_scales, b_scales, transpose_a, transpose_b, bias, cuda::sigmoid_func<float>(), y, depth);
+        break;
+      }
+
       case ActivationType::Swish: {
         dequantize_gemm_output_kernel<<<blocks, threads, 0, cuda::get_cuda_stream()>>>(
           c, a_scales, b_scales, transpose_a, transpose_b, bias, cuda::swish_func<float>(), y, depth);
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index 7d7b376fa..c9369fa67 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -883,6 +883,17 @@ TEST_P(OpDeviceFPTest, GELUSigmoid) {
   expect_storage_eq(output.to_float32(), expected, error);
 }
 
+TEST_P(OpDeviceFPTest, Sigmoid) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  StorageView input({2}, std::vector<float>{0.2, -1.3}, device);
+  StorageView expected({2}, std::vector<float>{0.54983395, 0.21416503}, device);
+  StorageView output(dtype, device);
+  ops::Sigmoid()(input.to(dtype), output);
+  expect_storage_eq(output.to_float32(), expected, error);
+}
+
 TEST_P(OpDeviceFPTest, Swish) {
   const Device device = GetParam().device;
   const DataType dtype = GetParam().dtype;