Create ctranslate2 Moonshine implementation.
Adds the following:
- C++ Moonshine model
- pybind11 bindings for the Python Moonshine model
- Moonshine model spec
- safetensors Moonshine model converter
- support for GroupNorm-style weights for LayerNorm
- support for multi-axis CUDA LayerNorm
njeffrie committed Oct 25, 2024
1 parent 383d063 commit 6373848
Showing 18 changed files with 871 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -129,6 +129,7 @@ set(SOURCES
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
src/layers/moonshine.cc
src/logging.cc
src/models/language_model.cc
src/models/model.cc
@@ -139,6 +140,7 @@ set(SOURCES
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
src/models/moonshine.cc
src/ops/activation.cc
src/ops/add.cc
src/ops/alibi_add.cc
75 changes: 75 additions & 0 deletions include/ctranslate2/layers/moonshine.h
@@ -0,0 +1,75 @@
#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
  namespace layers {

    class MoonshinePreprocessor : public Layer {
    public:
      MoonshinePreprocessor(const models::Model& model, const std::string& scope);

      void operator()(const StorageView& features, StorageView& output);

      DataType output_type() const override {
        return _conv3.output_type();
      }

      dim_t output_size() const override {
        return _conv3.output_size();
      }

      dim_t input_size() const {
        return _conv1.input_size();
      }
    private:
      const Conv1D _conv1;
      const ops::Tanh _tanh;
      const LayerNorm _norm;
      const Conv1D _conv2;
      const ops::GELU _gelu1;
      const Conv1D _conv3;
      const ops::GELU _gelu2;
      const ops::Transpose _transpose;
    };


    class MoonshineEncoder : public Layer {
    public:
      MoonshineEncoder(const models::Model& model, const std::string& scope);

      void operator()(const StorageView& features, StorageView& output);

      DataType output_type() const override {
        return _output_norm.output_type();
      }

      dim_t output_size() const override {
        return _output_norm.output_size();
      }

      bool is_encoded(const StorageView& features) const {
        // Input features shape: [batch_size, input_size, input_time]
        // Encoder output shape: [batch_size, input_time // 2, output_size]
        //
        // input_time is variable so we check that dimension 1 is different than its original value.

        return (features.rank() == 3
                && features.dim(2) == output_size()
                && features.dim(1) != 1);
      }

    private:
      const dim_t _num_heads;
      const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
      const LayerNorm _output_norm;
    };

    class MoonshineDecoder : public TransformerDecoder {
    public:
      using TransformerDecoder::TransformerDecoder;

      bool return_normalized_attention() const override {
        return false;
      }
    };
  }
}
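
The forward-pass bodies for these layers live in src/layers/moonshine.cc, which is added by this commit but not loaded on this page. As a rough orientation, a minimal sketch of how the preprocessor members could be chained is shown below; the buffer handling, the in-place activations, and the final [batch, time, channels] layout are assumptions, not the committed implementation.

// Hypothetical sketch only; the real definition is in src/layers/moonshine.cc.
// Members are applied in their declaration order from the header above.
void MoonshinePreprocessor::operator()(const StorageView& features, StorageView& output) {
  StorageView a(output_type(), features.device());
  StorageView b(output_type(), features.device());
  _conv1(features, a);    // strided Conv1D over the raw audio samples
  _tanh(a, a);            // in-place activation (assumption)
  _norm(a, b);            // LayerNorm loaded from GroupNorm-style weights
  _conv2(b, a);
  _gelu1(a, a);
  _conv3(a, b);
  _gelu2(b, b);
  _transpose(b, output);  // assumed [batch, channels, time] -> [batch, time, channels]
}

The encoder presumably runs its stack of TransformerEncoderLayer objects over this output and finishes with _output_norm, mirroring the existing Whisper encoder.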
134 changes: 134 additions & 0 deletions include/ctranslate2/models/moonshine.h
@@ -0,0 +1,134 @@
#pragma once

#include "ctranslate2/generation.h"
#include "ctranslate2/layers/moonshine.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
  namespace models {

    struct MoonshineOptions {
      // Beam size to use for beam search (set 1 to run greedy search).
      size_t beam_size = 5;

      // Beam search patience factor, as described in https://arxiv.org/abs/2204.05424.
      // The decoding will continue until beam_size*patience hypotheses are finished.
      float patience = 1;

      // Exponential penalty applied to the length during beam search.
      float length_penalty = 1;

      // Penalty applied to the score of previously generated tokens, as described in
      // https://arxiv.org/abs/1909.05858 (set > 1 to penalize).
      float repetition_penalty = 1;

      // Prevent repetitions of ngrams with this size (set 0 to disable).
      size_t no_repeat_ngram_size = 0;

      // Maximum generation length.
      size_t max_length = 448;

      // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
      size_t sampling_topk = 1;

      // High temperatures increase randomness.
      float sampling_temperature = 1;

      // Number of hypotheses to include in the result.
      size_t num_hypotheses = 1;

      // Include scores in the result.
      bool return_scores = false;

      // Suppress blank outputs at the beginning of the sampling.
      bool suppress_blank = true;

      // List of token IDs to suppress.
      // -1 will suppress a default set of symbols as defined in the model config.json file.
      std::vector<int> suppress_tokens = {-1};
    };

    struct MoonshineGenerationResult {
      std::vector<std::vector<std::string>> sequences;
      std::vector<std::vector<size_t>> sequences_ids;
      std::vector<float> scores;

      size_t num_sequences() const {
        return sequences.size();
      }

      bool has_scores() const {
        return !scores.empty();
      }
    };

    class MoonshineModel : public Model {
    public:
      const Vocabulary& get_vocabulary() const;

      size_t current_spec_revision() const override;
      bool is_quantizable(const std::string& variable_name) const override;
      bool is_linear_weight(const std::string& variable_name) const override;
      std::unique_ptr<Model> clone() const override;

      bool use_global_int16_scale() const override {
        return false;
      }

    protected:
      void initialize(ModelReader& model_reader) override;

    private:
      std::shared_ptr<const Vocabulary> _vocabulary;
    };

    class MoonshineReplica : public ModelReplica {
    public:
      static std::unique_ptr<MoonshineReplica> create_from_model(const Model& model);

      MoonshineReplica(const std::shared_ptr<const MoonshineModel>& model);

      StorageView encode(StorageView features, const bool to_cpu);

      std::vector<MoonshineGenerationResult>
      generate(StorageView features,
               const std::vector<std::vector<std::string>>& prompts,
               const MoonshineOptions& options);

      std::vector<MoonshineGenerationResult>
      generate(StorageView features,
               const std::vector<std::vector<size_t>>& prompts,
               const MoonshineOptions& options);

    private:
      const std::shared_ptr<const MoonshineModel> _model;
      const std::unique_ptr<layers::MoonshinePreprocessor> _preprocessor;
      const std::unique_ptr<layers::MoonshineEncoder> _encoder;
      const std::unique_ptr<layers::MoonshineDecoder> _decoder;

      size_t _sot_id;
      size_t _eot_id;

      StorageView maybe_encode(StorageView features);
    };

    class Moonshine : public ReplicaPool<MoonshineReplica> {
    public:
      using ReplicaPool::ReplicaPool;

      std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

      std::vector<std::future<MoonshineGenerationResult>>
      generate(const StorageView& features,
               std::vector<std::vector<std::string>> prompts,
               MoonshineOptions options = {});

      std::vector<std::future<MoonshineGenerationResult>>
      generate(const StorageView& features,
               std::vector<std::vector<size_t>> prompts,
               MoonshineOptions options = {});
    };

  }
}
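
Taken together, the Moonshine pool exposes the same future-based API as the other CTranslate2 speech models: prepared audio features go into generate() with one token prompt per batch entry. The usage sketch below is illustrative only; the "<s>" start token, the already-constructed pool, and the naive token join are assumptions rather than part of this diff.

// Hypothetical usage sketch; assumes a ctranslate2::models::Moonshine pool was
// already constructed from a converted Moonshine model directory, and that "<s>"
// is the start-of-text token. Neither detail is defined by this commit.
#include <string>
#include <utility>
#include <vector>

#include "ctranslate2/models/moonshine.h"

std::vector<std::string> transcribe(ctranslate2::models::Moonshine& model,
                                    const ctranslate2::StorageView& features) {
  ctranslate2::models::MoonshineOptions options;
  options.beam_size = 5;
  options.return_scores = true;

  // One prompt per batch entry.
  std::vector<std::vector<std::string>> prompts = {{"<s>"}};

  auto futures = model.generate(features, prompts, options);

  std::vector<std::string> texts;
  for (auto& future : futures) {
    ctranslate2::models::MoonshineGenerationResult result = future.get();
    std::string text;
    for (const auto& token : result.sequences[0])  // best hypothesis
      text += token;                               // naive join; real code would detokenize
    texts.emplace_back(std::move(text));
  }
  return texts;
}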
4 changes: 3 additions & 1 deletion include/ctranslate2/ops/layer_norm.h
@@ -7,7 +7,7 @@ namespace ctranslate2 {

class LayerNorm : public TernaryOp {
public:
-LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5);
+LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5, const bool multi_axis=false);

using TernaryOp::operator();
void operator()(const StorageView& beta,
@@ -32,10 +32,12 @@ namespace ctranslate2 {
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
const bool multi_axis,
StorageView& output) const;

const dim_t _axis;
const float _epsilon;
const bool _multi_axis;
};

}
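
The new multi_axis flag, together with the GroupNorm-style weights listed in the commit message, suggests that the normalization statistics are reduced over every dimension after the outer one while the scale and bias stay per-channel (in effect GroupNorm with a single group). The scalar reference below illustrates that reading; the memory layout, the per-channel gamma and beta, and the function itself are assumptions, not the CPU/CUDA kernel added by this commit.

#include <cmath>
#include <cstddef>

// Reference semantics assumed for the multi-axis case on a contiguous
// [outer_size, axis_size, inner_size] buffer: mean and variance are reduced over
// axis_size * inner_size elements (GroupNorm with a single group), while gamma and
// beta stay per-channel. Illustrative only; not the kernel added by this commit.
void layer_norm_multi_axis_reference(const float* input,
                                     const float* gamma,  // size: axis_size
                                     const float* beta,   // size: axis_size
                                     std::size_t outer_size,
                                     std::size_t axis_size,
                                     std::size_t inner_size,
                                     float epsilon,
                                     float* output) {
  const std::size_t reduce_size = axis_size * inner_size;
  for (std::size_t i = 0; i < outer_size; ++i) {
    const float* x = input + i * reduce_size;
    float* y = output + i * reduce_size;

    float mean = 0.f;
    for (std::size_t j = 0; j < reduce_size; ++j)
      mean += x[j];
    mean /= reduce_size;

    float variance = 0.f;
    for (std::size_t j = 0; j < reduce_size; ++j)
      variance += (x[j] - mean) * (x[j] - mean);
    variance /= reduce_size;

    const float inv_std = 1.f / std::sqrt(variance + epsilon);
    for (std::size_t c = 0; c < axis_size; ++c)
      for (std::size_t t = 0; t < inner_size; ++t) {
        const std::size_t j = c * inner_size + t;
        y[j] = (x[j] - mean) * inv_std * gamma[c] + beta[c];
      }
  }
}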
1 change: 1 addition & 0 deletions python/cpp/module.cc
@@ -89,4 +89,5 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
ctranslate2::python::register_moonshine(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
@@ -20,6 +20,7 @@ namespace ctranslate2 {
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);
void register_moonshine(py::module& m);

}
}
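
The registration itself is defined in python/cpp/moonshine.cc, one of the files not loaded on this page. The deliberately minimal pybind11 sketch below shows what such an entry point could contain; the bound class, its properties, and the omission of the model wrapper are assumptions, and the real binding follows CTranslate2's existing wrapper helpers.

// Hypothetical pybind11 sketch; the actual python/cpp/moonshine.cc is not shown
// in this diff and uses CTranslate2's existing binding helpers instead.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ctranslate2/models/moonshine.h"

namespace py = pybind11;

namespace ctranslate2 {
  namespace python {

    void register_moonshine(py::module& m) {
      // Expose the generation result; the model wrapper itself is omitted here.
      py::class_<models::MoonshineGenerationResult>(m, "MoonshineGenerationResult")
        .def_readonly("sequences", &models::MoonshineGenerationResult::sequences)
        .def_readonly("sequences_ids", &models::MoonshineGenerationResult::sequences_ids)
        .def_readonly("scores", &models::MoonshineGenerationResult::scores);
    }

  }
}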