Directly register finalized model variables
guillaumekln committed Jan 27, 2023
1 parent a7bc320 commit 9de7dd6
Showing 14 changed files with 451 additions and 293 deletions.
9 changes: 9 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -131,6 +131,15 @@ namespace ctranslate2 {
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr);

static void
register_weight(const std::string& name,
std::shared_ptr<StorageView> weight,
models::Model& model,
std::unordered_map<std::string, std::shared_ptr<StorageView>>& variables,
Device target_device,
ComputeType compute_type,
bool allow_packing = false);
private:
bool _packed_weight;
const StorageView& _weight;
29 changes: 14 additions & 15 deletions include/ctranslate2/models/model.h
@@ -19,6 +19,8 @@ namespace ctranslate2 {
class SequenceToSequenceReplica;
class SequenceGeneratorReplica;

using VariablesCollection = std::unordered_map<std::string, std::shared_ptr<StorageView>>;

// Base class for models.
class Model : public std::enable_shared_from_this<Model> {
public:
@@ -114,7 +116,13 @@ namespace ctranslate2 {
return static_cast<Enum>(get_attribute_with_default<int32_t>(name, 0));
}

// Registers a model variable.
void register_variable(std::string name, std::shared_ptr<StorageView> variable);

protected:
// Updates a variable name on load, e.g. for backward compatibility.
virtual void update_variable_name(std::string&) const {}

// Returns true if the variable is quantizable and should respect compute_type.
virtual bool is_quantizable(const std::string& variable_name) const;

@@ -125,35 +133,26 @@ namespace ctranslate2 {
virtual bool is_packable(const std::string& variable_name) const;

// Returns true if the variable can be converted to another type.
virtual bool is_convertible(const StorageView& variable, const std::string& name) const;

// Models can override these methods to execute some transformations if needed
// (e.g. a variable name changed in a newer spec revision).
virtual void register_variable(std::string name, StorageView variable);
virtual void register_variable_alias(std::string alias, const std::string& variable_name);
virtual void remove_variable(const std::string& name);
virtual bool is_convertible(const std::string& name, DataType dtype) const;

// Runs some initialization after the model is loaded.
virtual void initialize(ModelReader&) {}

virtual std::unique_ptr<Model> clone() const = 0;

private:
void process_linear_weights();
void set_compute_type(ComputeType type, Device device, int device_index);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
ComputeType infer_compute_type() const;
void register_variable(std::string name, StorageView variable);
void register_variable_alias(std::string alias, const std::string& variable_name);
void remove_variable(const std::string& name);

private:
Device _device = Device::CPU;
int _device_index = 0;
size_t _binary_version = 0;
size_t _spec_revision = 0;
ComputeType _compute_type = ComputeType::DEFAULT;
ComputeType _effective_compute_type = ComputeType::DEFAULT;
dim_t _preferred_size_multiple = 1;
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
VariablesCollection _variable_index;
};

template<>
2 changes: 1 addition & 1 deletion include/ctranslate2/models/transformer.h
@@ -15,9 +15,9 @@ namespace ctranslate2 {
std::unique_ptr<SequenceToSequenceReplica> as_sequence_to_sequence() const override;

protected:
void update_variable_name(std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
bool is_packable(const std::string& variable_name) const override;
void register_variable(std::string name, StorageView variable) override;
void initialize(ModelReader& model_reader) override;
std::unique_ptr<Model> clone() const override;

14 changes: 14 additions & 0 deletions include/ctranslate2/ops/gemm.h
@@ -26,6 +26,20 @@ namespace ctranslate2 {
const StorageView* a_shift_compensation = nullptr,
const StorageView* bias = nullptr) const;

// Return the packed representation of b, if implemented by the GEMM backend.
static StorageView pack_b_input(const StorageView& b,
const bool transpose,
const dim_t k,
const dim_t n,
const float alpha);

// Return the compensation term when s8s8s32 is implemented with u8s8s32.
static StorageView compensate_u8_a_input(const StorageView& b,
const bool transpose,
const dim_t k,
const dim_t n,
const float alpha);

private:
float _alpha;
float _beta;
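
As background for compensate_u8_a_input: when an s8·s8 GEMM is emulated with a u8·s8 kernel, the signed input A is shifted into the unsigned range by a constant c (128 for int8) before the product is computed. A sketch of the algebra, assuming that convention (the exact shift and memory layout are backend details):

  A_{u8} = A + c\,\mathbf{1}, \qquad
  \alpha A_{u8} B = \alpha A B + \alpha c\, \mathbf{1}^{\top} B
  \quad\Longrightarrow\quad
  \alpha A B = \alpha A_{u8} B - \alpha c \sum_{k} B_{k,n}

The correction term \alpha c \sum_k B_{k,n} depends only on the weight B, which is why it can be computed once per weight and stored as a model variable (see the comment in src/layers/common.cc below).
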
2 changes: 2 additions & 0 deletions include/ctranslate2/types.h
@@ -49,6 +49,8 @@ namespace ctranslate2 {

// Gets the weights data type for the given compute type.
DataType compute_type_to_data_type(const ComputeType compute_type);
ComputeType data_type_to_compute_type(const DataType quantizable_type,
const DataType float_type);

// Gets the default floating point type for the given compute type.
DataType get_default_float_type(const ComputeType compute_type);
18 changes: 18 additions & 0 deletions python/tests/test_marian.py
@@ -1,5 +1,6 @@
import os

import pytest
import test_utils

import ctranslate2
@@ -13,3 +14,20 @@ def test_marian_model_conversion(tmpdir):
    translator = ctranslate2.Translator(output_dir)
    output = translator.translate_batch([["▁Hello", "▁world", "!"]])
    assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"]


@pytest.mark.parametrize(
    "quantization", [None, "int8", "int16", "float16", "int8_float16"]
)
def test_marian_model_quantization(tmpdir, quantization):
    model_dir = os.path.join(test_utils.get_data_dir(), "models", "opus-mt-ende")
    converter = ctranslate2.converters.OpusMTConverter(model_dir)
    output_dir = str(tmpdir.join("ctranslate2_model"))
    converter.convert(output_dir, quantization=quantization)

    compute_types = ["default"] + list(ctranslate2.get_supported_compute_types("cpu"))

    for compute_type in compute_types:
        translator = ctranslate2.Translator(output_dir, compute_type=compute_type)
        output = translator.translate_batch([["▁Hello", "▁world", "!"]])
        assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"]
4 changes: 2 additions & 2 deletions src/cpu/backend.cc
@@ -93,9 +93,9 @@ namespace ctranslate2 {
return gemm_s8_backend == cpu::GemmBackend::MKL || gemm_s8_backend == cpu::GemmBackend::DNNL;
}

bool should_pack_gemm_weights() {
bool pack_gemm_weights(ComputeType compute_type) {
static const bool should_pack = read_bool_from_env("CT2_USE_EXPERIMENTAL_PACKED_GEMM");
return should_pack;
return should_pack && get_gemm_backend(compute_type) == GemmBackend::MKL;
}

#ifdef CT2_WITH_RUY
2 changes: 1 addition & 1 deletion src/cpu/backend.h
@@ -25,7 +25,7 @@ namespace ctranslate2 {
GemmBackend get_gemm_backend(ComputeType compute_type);
bool has_gemm_backend(ComputeType compute_type);
bool prefer_u8s8s32_gemm();
bool should_pack_gemm_weights();
bool pack_gemm_weights(ComputeType compute_type);
#ifdef CT2_WITH_RUY
ruy::Context *get_ruy_context();
#endif
56 changes: 53 additions & 3 deletions src/layers/common.cc
@@ -254,16 +254,18 @@ namespace ctranslate2 {
return model.get_variable(scope + "/weight");
}

static inline bool require_compensation(Device device, DataType dtype) {
return device == Device::CPU && dtype == DataType::INT8 && cpu::prefer_u8s8s32_gemm();
}

Dense::Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type)
: _packed_weight(false)
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale"))
, _u8_shift_compensation((_weight.device() == Device::CPU
&& _weight.dtype() == DataType::INT8
&& cpu::prefer_u8s8s32_gemm())
, _u8_shift_compensation(require_compensation(_weight.device(), _weight.dtype())
? &model.get_variable(scope + "/weight_compensation")
: nullptr)
, _partial_weight(_weight.device(), _weight.dtype())
@@ -347,6 +349,54 @@ namespace ctranslate2 {
}
}

void
Dense::register_weight(const std::string& name,
std::shared_ptr<StorageView> weight_ptr,
models::Model& model,
std::unordered_map<std::string, std::shared_ptr<StorageView>>& variables,
Device target_device,
ComputeType compute_type,
bool allow_packing) {
auto& weight = *weight_ptr;

if (weight.device() != target_device)
weight = weight.to(target_device);

const bool transpose = true;
const dim_t k = weight.dim(1);
const dim_t n = weight.dim(0);
const float alpha = 1;

// If the target Gemm implementation prefers the u8s8s32 format, we can shift
// the input of linear layers to the u8 domain and add a compensation term.
// This term only depends on the linear weight, so we can compute it once and
// store it as a model variable.
if (require_compensation(weight.device(), weight.dtype())) {
const std::string suffix = "_compensation";

auto& compensation = variables[suffix];
if (!compensation)
compensation = std::make_shared<StorageView>(
ops::Gemm::compensate_u8_a_input(weight, transpose, k, n, alpha));

model.register_variable(name + suffix, compensation);
}

// If requested, linear weights can be packed for the Gemm call.
if (allow_packing && cpu::pack_gemm_weights(compute_type)) {
const std::string suffix = "_packed";

auto& packed_weight = variables[suffix];
if (!packed_weight)
packed_weight = std::make_shared<StorageView>(
ops::Gemm::pack_b_input(weight, transpose, k, n, alpha));

model.register_variable(name + suffix, packed_weight);
} else {
model.register_variable(name, weight_ptr);
}
}
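
A side note on the caching in Dense::register_weight above: unordered_map::operator[] default-constructs a null shared_ptr on first access, so the "compute once, then share" behaviour falls out of a simple null check. A minimal, self-contained sketch of that pattern with hypothetical names (it does not use the CTranslate2 API):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for the per-weight cache of derived variables.
using Cache = std::unordered_map<std::string, std::shared_ptr<int>>;

std::shared_ptr<int> get_or_compute(Cache& cache, const std::string& suffix, int value) {
  auto& entry = cache[suffix];             // null shared_ptr on first access
  if (!entry)
    entry = std::make_shared<int>(value);  // "computed" only once per suffix
  return entry;
}

int main() {
  Cache per_weight_cache;
  auto first = get_or_compute(per_weight_cache, "_compensation", 42);
  auto second = get_or_compute(per_weight_cache, "_compensation", 99);
  std::cout << (first == second) << ' ' << *second << '\n';  // prints "1 42"
}
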


LayerNorm::LayerNorm(const models::Model& model, const std::string& scope)
: _beta(model.get_variable_if_exists(scope + "/beta"))
(The diffs for the remaining 5 changed files are not shown.)