Skip to content

Commit

Permalink
Compute queries scale in the layer constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Jul 29, 2020
1 parent 836ef00 commit 781a46a
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace ctranslate2 {
bool self_attention,
LayerNormStrategy layer_norm_strategy = LayerNormStrategy::Input);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& queries,
const StorageView* memory,
const StorageView* memory_lengths,
Expand All @@ -41,6 +42,7 @@ namespace ctranslate2 {
const StorageView* _relative_position_keys;
const StorageView* _relative_position_values;
const dim_t _maximum_relative_position;
const float _queries_scale;
const ops::Transpose _transpose_op;

void split_heads(StorageView& x, StorageView& y) const;
Expand Down
4 changes: 4 additions & 0 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ namespace ctranslate2 {
class Layer {
public:
virtual DataType output_type() const = 0;
virtual dim_t output_size() const = 0;
};

class Embeddings : public Layer
{
public:
Embeddings(const models::Model& model, const std::string& scope);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& ids, StorageView& output) const;
private:
const ops::Gather _gather_op;
Expand All @@ -29,6 +31,7 @@ namespace ctranslate2 {
public:
Dense(const models::Model& model, const std::string& scope);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
void mask_weights(const StorageView& index);
void reset_mask();
Expand All @@ -52,6 +55,7 @@ namespace ctranslate2 {
public:
LayerNorm(const models::Model& model, const std::string& scope);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
private:
const ops::LayerNorm _norm_op;
Expand Down
2 changes: 2 additions & 0 deletions include/ctranslate2/models/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ namespace ctranslate2 {
public:
TransformerEncoder(const TransformerModel& model, const std::string& scope);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& ids,
const StorageView& lengths,
StorageView& output) override;
Expand All @@ -113,6 +114,7 @@ namespace ctranslate2 {
const std::string& scope,
const bool with_encoder_attention = true);
DataType output_type() const override;
dim_t output_size() const override;
void set_vocabulary_mask(const StorageView& ids) override;
void reset_vocabulary_mask() override;
layers::DecoderState initial_state() const override;
Expand Down
10 changes: 6 additions & 4 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,18 @@ namespace ctranslate2 {
, _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values"))
, _maximum_relative_position(_relative_position_keys
? (_relative_position_keys->dim(0) - 1) / 2 : 0)
, _queries_scale(1.f / std::sqrt(static_cast<float>(_layer_norm.output_size() / num_heads)))
, _transpose_op({0, 2, 1, 3}) {
}

// The attention output has the same dtype as the layer's normalization
// parameters, so delegate the query to the contained LayerNorm.
DataType MultiHeadAttention::output_type() const {
  const DataType dtype = _layer_norm.output_type();
  return dtype;
}

// Output depth of the attention layer; it matches the depth seen by the
// input layer normalization, so forward the call to it.
dim_t MultiHeadAttention::output_size() const {
  const dim_t depth = _layer_norm.output_size();
  return depth;
}

void MultiHeadAttention::operator()(const StorageView& queries,
const StorageView* memory,
const StorageView* memory_lengths,
Expand Down Expand Up @@ -230,9 +235,6 @@ namespace ctranslate2 {
}
}

const dim_t dk = queries.dim(-1) / _num_heads;
const float queries_scale = 1.0 / sqrt(dk);

StorageView& context = queries_proj; // Reuse storage.
dot_product_attention(split_queries,
split_keys,
Expand All @@ -243,7 +245,7 @@ namespace ctranslate2 {
_maximum_relative_position,
context,
attention,
queries_scale,
_queries_scale,
bool(cached_keys));

StorageView& combined = values_proj; // Reuse storage.
Expand Down
12 changes: 12 additions & 0 deletions src/layers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ namespace ctranslate2 {
return _embeddings.dtype() == DataType::FLOAT16 ? DataType::FLOAT16 : DataType::FLOAT;
}

// Depth of the embedding vectors: the second dimension of the
// (vocabulary_size x depth) embedding weight matrix.
dim_t Embeddings::output_size() const {
  const dim_t embedding_depth = _embeddings.dim(1);
  return embedding_depth;
}

void Embeddings::operator()(const StorageView& ids,
StorageView& output) const {
PROFILE("Embeddings");
Expand Down Expand Up @@ -88,6 +92,10 @@ namespace ctranslate2 {
return _weight.dtype() == DataType::FLOAT16 ? DataType::FLOAT16 : DataType::FLOAT;
}

// Number of output units of the linear layer: the first dimension of the
// (output_size x input_size) weight matrix.
dim_t Dense::output_size() const {
  const dim_t num_units = _weight.dim(0);
  return num_units;
}

void Dense::mask_weights(const StorageView& index) {
if (_packed_weight)
throw std::runtime_error("Can't mask pre-packed weight");
Expand Down Expand Up @@ -153,6 +161,10 @@ namespace ctranslate2 {
return _beta.dtype();
}

// Normalized depth: layer norm is elementwise over the last dimension,
// so the size of the beta (bias) vector is the output depth.
dim_t LayerNorm::output_size() const {
  const dim_t depth = _beta.size();
  return depth;
}

// Apply layer normalization to the input, writing the normalized result
// into the output storage. Delegates to the underlying LayerNorm op with
// this layer's learned beta (bias) and gamma (scale) parameters.
void LayerNorm::operator()(const StorageView& input, StorageView& output) const {
  _norm_op(/*beta=*/_beta, /*gamma=*/_gamma, input, output);
}
Expand Down
8 changes: 8 additions & 0 deletions src/models/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ namespace ctranslate2 {
return _output_norm.output_type();
}

// The encoder output is produced by the final layer normalization,
// so its depth defines the encoder's output size.
dim_t TransformerEncoder::output_size() const {
  const dim_t depth = _output_norm.output_size();
  return depth;
}

void TransformerEncoder::operator()(const StorageView& ids,
const StorageView& lengths,
StorageView& output) {
Expand Down Expand Up @@ -314,6 +318,10 @@ namespace ctranslate2 {
return _proj.output_type();
}

// The decoder output comes from the final projection layer, so its
// output size (typically the target vocabulary size) is returned.
dim_t TransformerDecoder::output_size() const {
  const dim_t depth = _proj.output_size();
  return depth;
}

// Restrict decoding to a subset of the target vocabulary by masking the
// weights of the output projection to the given token ids.
void TransformerDecoder::set_vocabulary_mask(const StorageView& ids) {
  _proj.mask_weights(ids);
}
Expand Down

0 comments on commit 781a46a

Please sign in to comment.