Skip to content

Commit

Permalink
update based on the suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
hkwon committed Aug 14, 2024
1 parent d844362 commit a1c112f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 46 deletions.
30 changes: 7 additions & 23 deletions include/ctranslate2/layers/wav2vec2.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,12 @@
namespace ctranslate2 {
namespace layers {

class Wav2Vec2LayerNormConvLayer0 : public Layer {
public:
Wav2Vec2LayerNormConvLayer0(const models::Model& model, const std::string& scope);

void operator()(const StorageView& input, StorageView& output) const;

DataType output_type() const override {
return _conv.output_type();
}

dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

class Wav2Vec2LayerNormConvLayer : public Layer {
public:
Wav2Vec2LayerNormConvLayer(const models::Model& model, const std::string& scope);
Wav2Vec2LayerNormConvLayer(const models::Model& model,
const std::string& scope,
dim_t stride,
dim_t padding);

void operator()(const StorageView& input, StorageView& output) const;

Expand All @@ -41,6 +23,8 @@ namespace ctranslate2 {
}

private:
dim_t _stride;
dim_t _padding;
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
Expand Down Expand Up @@ -97,7 +81,7 @@ namespace ctranslate2 {
}

private:
const Wav2Vec2LayerNormConvLayer0 _feat_layer0;
const Wav2Vec2LayerNormConvLayer _feat_layer0;
const std::vector<std::unique_ptr<const Wav2Vec2LayerNormConvLayer>> _feat_layers;
const LayerNorm _fp_norm;
const Dense _fp_ff;
Expand Down
34 changes: 11 additions & 23 deletions src/layers/wav2vec2.cc
Original file line number Diff line number Diff line change
@@ -1,30 +1,16 @@
#include "ctranslate2/layers/wav2vec2.h"
#include <iostream>


namespace ctranslate2 {
namespace layers {

Wav2Vec2LayerNormConvLayer0::Wav2Vec2LayerNormConvLayer0(const models::Model& model, const std::string& scope)
: _conv(model, scope + "/conv", /*stride=*/5, /*padding=*/0)
, _transpose({0, 2, 1})
, _output_norm(model, scope + "/layer_norm") {
}

void Wav2Vec2LayerNormConvLayer0::operator()(const StorageView& input, StorageView& output) const{
PROFILE("Wav2Vec2LayerNormConvLayer0");

StorageView buffer(input.dtype(), input.device());
buffer = std::move(input);
_conv(buffer, output);
_transpose(output, buffer);
_output_norm(buffer, output);
_transpose(output, buffer);
_gelu(buffer, output);
}

Wav2Vec2LayerNormConvLayer::Wav2Vec2LayerNormConvLayer(const models::Model& model, const std::string& scope)
: _conv(model, scope + "/conv", /*stride=*/2, /*padding=*/0)
Wav2Vec2LayerNormConvLayer::Wav2Vec2LayerNormConvLayer(const models::Model& model,
const std::string& scope,
dim_t stride,
dim_t padding)
: _stride(stride)
, _padding(padding)
, _conv(model, scope + "/conv", _stride, _padding)
, _transpose({0, 2, 1})
, _output_norm(model, scope + "/layer_norm") {
}
Expand Down Expand Up @@ -60,9 +46,11 @@ namespace ctranslate2 {
}

Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
: _feat_layer0(model, scope + "/feat_layer0")
: _feat_layer0(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0)
, _feat_layers(build_layers_list<const Wav2Vec2LayerNormConvLayer>(model,
scope + "/feat_layer"))
scope + "/feat_layer",
/*stride=*/2,
/*padding=*/0))
, _fp_norm(model, scope + "/fp_layer_norm")
, _fp_ff(model, scope + "/fp_projection", nullptr, true)
, _pos_conv_embed(model, scope + "/pos_conv_embed")
Expand Down

0 comments on commit a1c112f

Please sign in to comment.