From a1c112f19fc06308f7b562e321bc7a9ff19e135b Mon Sep 17 00:00:00 2001 From: hkwon Date: Wed, 14 Aug 2024 11:01:34 -0700 Subject: [PATCH] update based on the suggestions --- include/ctranslate2/layers/wav2vec2.h | 30 ++++++----------------- src/layers/wav2vec2.cc | 34 +++++++++------------------ 2 files changed, 18 insertions(+), 46 deletions(-) diff --git a/include/ctranslate2/layers/wav2vec2.h b/include/ctranslate2/layers/wav2vec2.h index 2d6a7d12c..29dea9783 100644 --- a/include/ctranslate2/layers/wav2vec2.h +++ b/include/ctranslate2/layers/wav2vec2.h @@ -5,30 +5,12 @@ namespace ctranslate2 { namespace layers { - class Wav2Vec2LayerNormConvLayer0 : public Layer { - public: - Wav2Vec2LayerNormConvLayer0(const models::Model& model, const std::string& scope); - - void operator()(const StorageView& input, StorageView& output) const; - - DataType output_type() const override { - return _conv.output_type(); - } - - dim_t output_size() const override { - return _conv.output_size(); - } - - private: - const Conv1D _conv; - const LayerNorm _output_norm; - const ops::Transpose _transpose; - const ops::GELU _gelu; - }; - class Wav2Vec2LayerNormConvLayer : public Layer { public: - Wav2Vec2LayerNormConvLayer(const models::Model& model, const std::string& scope); + Wav2Vec2LayerNormConvLayer(const models::Model& model, + const std::string& scope, + dim_t stride, + dim_t padding); void operator()(const StorageView& input, StorageView& output) const; @@ -41,6 +23,8 @@ namespace ctranslate2 { } private: + dim_t _stride; + dim_t _padding; const Conv1D _conv; const LayerNorm _output_norm; const ops::Transpose _transpose; @@ -97,7 +81,7 @@ namespace ctranslate2 { } private: - const Wav2Vec2LayerNormConvLayer0 _feat_layer0; + const Wav2Vec2LayerNormConvLayer _feat_layer0; const std::vector> _feat_layers; const LayerNorm _fp_norm; const Dense _fp_ff; diff --git a/src/layers/wav2vec2.cc b/src/layers/wav2vec2.cc index 04ccb077e..defbf0d84 100644 --- a/src/layers/wav2vec2.cc +++ b/src/layers/wav2vec2.cc @@ -1,30 +1,16 @@ #include "ctranslate2/layers/wav2vec2.h" -#include namespace ctranslate2 { namespace layers { - Wav2Vec2LayerNormConvLayer0::Wav2Vec2LayerNormConvLayer0(const models::Model& model, const std::string& scope) - : _conv(model, scope + "/conv", /*stride=*/5, /*padding=*/0) - , _transpose({0, 2, 1}) - , _output_norm(model, scope + "/layer_norm") { - } - - void Wav2Vec2LayerNormConvLayer0::operator()(const StorageView& input, StorageView& output) const{ - PROFILE("Wav2Vec2LayerNormConvLayer0"); - - StorageView buffer(input.dtype(), input.device()); - buffer = std::move(input); - _conv(buffer, output); - _transpose(output, buffer); - _output_norm(buffer, output); - _transpose(output, buffer); - _gelu(buffer, output); - } - - Wav2Vec2LayerNormConvLayer::Wav2Vec2LayerNormConvLayer(const models::Model& model, const std::string& scope) - : _conv(model, scope + "/conv", /*stride=*/2, /*padding=*/0) + Wav2Vec2LayerNormConvLayer::Wav2Vec2LayerNormConvLayer(const models::Model& model, + const std::string& scope, + dim_t stride, + dim_t padding) + : _stride(stride) + , _padding(padding) + , _conv(model, scope + "/conv", _stride, _padding) , _transpose({0, 2, 1}) , _output_norm(model, scope + "/layer_norm") { } @@ -60,9 +46,11 @@ namespace ctranslate2 { } Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope) - : _feat_layer0(model, scope + "/feat_layer0") + : _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0) , _feat_layers(build_layers_list(model, - scope + "/feat_layer")) + scope + "/feat_layer", + /*stride=*/2, + /*padding=*/0)) , _fp_norm(model, scope + "/fp_layer_norm") , _fp_ff(model, scope + "/fp_projection", nullptr, true) , _pos_conv_embed(model, scope + "/pos_conv_embed")