Directly register finalized model variables #1058

Draft · wants to merge 3 commits into master
9 changes: 9 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -131,6 +131,15 @@ namespace ctranslate2 {
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr);

static void
register_weight(const std::string& name,
std::shared_ptr<StorageView> weight,
models::Model& model,
std::unordered_map<std::string, std::shared_ptr<StorageView>>& variables,
Device target_device,
ComputeType compute_type,
bool allow_packing = false);
private:
bool _packed_weight;
const StorageView& _weight;
29 changes: 14 additions & 15 deletions include/ctranslate2/models/model.h
@@ -19,6 +19,8 @@ namespace ctranslate2 {
class SequenceToSequenceReplica;
class SequenceGeneratorReplica;

using VariablesCollection = std::unordered_map<std::string, std::shared_ptr<StorageView>>;

// Base class for models.
class Model : public std::enable_shared_from_this<Model> {
public:
@@ -114,7 +116,13 @@ namespace ctranslate2 {
return static_cast<Enum>(get_attribute_with_default<int32_t>(name, 0));
}

// Registers a model variable.
void register_variable(std::string name, std::shared_ptr<StorageView> variable);

protected:
// Updates a variable name on load, e.g. for backward compatibility.
virtual void update_variable_name(std::string&) const {}

// Returns true if the variable is quantizable and should respect compute_type.
virtual bool is_quantizable(const std::string& variable_name) const;

@@ -125,35 +133,26 @@
virtual bool is_packable(const std::string& variable_name) const;

// Returns true if the variable can be converted to another type.
virtual bool is_convertible(const StorageView& variable, const std::string& name) const;

// Models can override these methods to execute some transformations if needed
// (e.g. a variable name changed in a newer spec revision).
virtual void register_variable(std::string name, StorageView variable);
virtual void register_variable_alias(std::string alias, const std::string& variable_name);
virtual void remove_variable(const std::string& name);
virtual bool is_convertible(const std::string& name, DataType dtype) const;

// Runs some initialization after the model is loaded.
virtual void initialize(ModelReader&) {}

virtual std::unique_ptr<Model> clone() const = 0;

private:
void process_linear_weights();
void set_compute_type(ComputeType type, Device device, int device_index);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
ComputeType infer_compute_type() const;
void register_variable(std::string name, StorageView variable);
void register_variable_alias(std::string alias, const std::string& variable_name);
void remove_variable(const std::string& name);

private:
Device _device = Device::CPU;
int _device_index = 0;
size_t _binary_version = 0;
size_t _spec_revision = 0;
ComputeType _compute_type = ComputeType::DEFAULT;
ComputeType _effective_compute_type = ComputeType::DEFAULT;
dim_t _preferred_size_multiple = 1;
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
VariablesCollection _variable_index;
};

template<>
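The now-public register_variable pairs with the new update_variable_name hook, so subclasses can rename legacy variables while finalized variables are registered directly as shared StorageViews. A minimal sketch, assuming a hypothetical MyModel subclass and variable names that are not part of this diff:

// Hypothetical override in a model subclass; the name mapping is illustrative only.
void MyModel::update_variable_name(std::string& name) const {
  if (name == "decoder/projection/W")        // assumed legacy spelling
    name = "decoder/projection/weight";      // assumed current spelling
}

// Direct registration of an already finalized variable:
auto weight = std::make_shared<ctranslate2::StorageView>(/* finalized tensor */);
model.register_variable("decoder/projection/weight", std::move(weight));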
2 changes: 1 addition & 1 deletion include/ctranslate2/models/transformer.h
@@ -15,9 +15,9 @@ namespace ctranslate2 {
std::unique_ptr<SequenceToSequenceReplica> as_sequence_to_sequence() const override;

protected:
void update_variable_name(std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
bool is_packable(const std::string& variable_name) const override;
void register_variable(std::string name, StorageView variable) override;
void initialize(ModelReader& model_reader) override;
std::unique_ptr<Model> clone() const override;

14 changes: 14 additions & 0 deletions include/ctranslate2/ops/gemm.h
@@ -26,6 +26,20 @@ namespace ctranslate2 {
const StorageView* a_shift_compensation = nullptr,
const StorageView* bias = nullptr) const;

// Return the packed representation of b, if implemented by the GEMM backend.
static StorageView pack_b_input(const StorageView& b,
const bool transpose,
const dim_t k,
const dim_t n,
const float alpha);

// Return the compensation term when s8s8s32 is implemented with u8s8s32.
static StorageView compensate_u8_a_input(const StorageView& b,
const bool transpose,
const dim_t k,
const dim_t n,
const float alpha);

private:
float _alpha;
float _beta;
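compensate_u8_a_input corresponds to the usual shift trick for backends that only expose u8*s8->s32 GEMM: shifting the s8 activations A by 128 into the u8 domain gives (A + 128) * B = A * B + 128 * colsum(B), so a per-column term that depends only on B can be precomputed from the weight. A reference-only sketch of that term, under assumed layout and sign conventions (the actual backend code may differ):

#include <cstdint>
#include <vector>

// compensation[j] = -128 * alpha * sum_k b(k, j)
// Assumed storage: transpose == true  -> b is laid out as (n, k), as for Dense weights;
//                  transpose == false -> b is laid out as (k, n).
std::vector<int32_t> u8_shift_compensation_reference(const int8_t* b,
                                                     bool transpose,
                                                     int64_t k,
                                                     int64_t n,
                                                     float alpha) {
  std::vector<int32_t> compensation(n);
  for (int64_t j = 0; j < n; ++j) {
    int64_t sum = 0;
    for (int64_t i = 0; i < k; ++i)
      sum += transpose ? b[j * k + i] : b[i * n + j];
    compensation[j] = static_cast<int32_t>(-128.f * alpha * static_cast<float>(sum));
  }
  return compensation;
}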
6 changes: 4 additions & 2 deletions include/ctranslate2/types.h
@@ -47,8 +47,10 @@ namespace ctranslate2 {
const int device_index = 0,
const bool enable_fallback = false);

// Gets the weights data type for the given compute type.
DataType compute_type_to_data_type(const ComputeType compute_type);
// Gets the weight type and the float type for the given compute type.
std::pair<DataType, DataType> compute_type_to_data_type(const ComputeType compute_type);
ComputeType data_type_to_compute_type(const DataType weight_type,
const DataType float_type);

// Gets the default floating point type for the given compute type.
DataType get_default_float_type(const ComputeType compute_type);
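The pair-returning overload presumably splits a mixed compute type into its weight type and its activation float type, with data_type_to_compute_type as the inverse mapping used when inferring the compute type from loaded variables. An assumed correspondence, for illustration only (enumerator spellings not verified against this revision):

// compute_type_to_data_type(ComputeType::INT8_FLOAT16) -> {DataType::INT8, DataType::FLOAT16}
// data_type_to_compute_type(DataType::INT8, DataType::FLOAT16) -> ComputeType::INT8_FLOAT16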
18 changes: 18 additions & 0 deletions python/tests/test_marian.py
@@ -1,5 +1,6 @@
import os

import pytest
import test_utils

import ctranslate2
@@ -13,3 +14,20 @@ def test_marian_model_conversion(tmpdir):
translator = ctranslate2.Translator(output_dir)
output = translator.translate_batch([["▁Hello", "▁world", "!"]])
assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"]


@pytest.mark.parametrize(
"quantization", [None, "int8", "int16", "float16", "int8_float16"]
)
def test_marian_model_quantization(tmpdir, quantization):
model_dir = os.path.join(test_utils.get_data_dir(), "models", "opus-mt-ende")
converter = ctranslate2.converters.OpusMTConverter(model_dir)
output_dir = str(tmpdir.join("ctranslate2_model"))
converter.convert(output_dir, quantization=quantization)

compute_types = ["default"] + list(ctranslate2.get_supported_compute_types("cpu"))

for compute_type in compute_types:
translator = ctranslate2.Translator(output_dir, compute_type=compute_type)
output = translator.translate_batch([["▁Hello", "▁world", "!"]])
assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"]
4 changes: 2 additions & 2 deletions src/cpu/backend.cc
@@ -93,9 +93,9 @@ namespace ctranslate2 {
return gemm_s8_backend == cpu::GemmBackend::MKL || gemm_s8_backend == cpu::GemmBackend::DNNL;
}

bool should_pack_gemm_weights() {
bool pack_gemm_weights(ComputeType compute_type) {
static const bool should_pack = read_bool_from_env("CT2_USE_EXPERIMENTAL_PACKED_GEMM");
return should_pack;
return should_pack && get_gemm_backend(compute_type) == GemmBackend::MKL;
}

#ifdef CT2_WITH_RUY
2 changes: 1 addition & 1 deletion src/cpu/backend.h
@@ -25,7 +25,7 @@ namespace ctranslate2 {
GemmBackend get_gemm_backend(ComputeType compute_type);
bool has_gemm_backend(ComputeType compute_type);
bool prefer_u8s8s32_gemm();
bool should_pack_gemm_weights();
bool pack_gemm_weights(ComputeType compute_type);
#ifdef CT2_WITH_RUY
ruy::Context *get_ruy_context();
#endif
56 changes: 53 additions & 3 deletions src/layers/common.cc
@@ -254,16 +254,18 @@ namespace ctranslate2 {
return model.get_variable(scope + "/weight");
}

static inline bool require_compensation(Device device, DataType dtype) {
return device == Device::CPU && dtype == DataType::INT8 && cpu::prefer_u8s8s32_gemm();
}

Dense::Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type)
: _packed_weight(false)
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale"))
, _u8_shift_compensation((_weight.device() == Device::CPU
&& _weight.dtype() == DataType::INT8
&& cpu::prefer_u8s8s32_gemm())
, _u8_shift_compensation(require_compensation(_weight.device(), _weight.dtype())
? &model.get_variable(scope + "/weight_compensation")
: nullptr)
, _partial_weight(_weight.device(), _weight.dtype())
@@ -347,6 +349,54 @@ namespace ctranslate2 {
}
}

void
Dense::register_weight(const std::string& name,
std::shared_ptr<StorageView> weight_ptr,
models::Model& model,
std::unordered_map<std::string, std::shared_ptr<StorageView>>& variables,
Device target_device,
ComputeType compute_type,
bool allow_packing) {
auto& weight = *weight_ptr;

if (weight.device() != target_device)
weight = weight.to(target_device);

const bool transpose = true;
const dim_t k = weight.dim(1);
const dim_t n = weight.dim(0);
const float alpha = 1;

// If the target Gemm implementation prefers the u8s8s32 format, we can shift
// the input of linear layers to the u8 domain and add a compensation term.
// This term only depends on the linear weight, so we can compute it once and
// store it as a model variable.
if (require_compensation(weight.device(), weight.dtype())) {
const std::string suffix = "_compensation";

auto& compensation = variables[suffix];
if (!compensation)
compensation = std::make_shared<StorageView>(
ops::Gemm::compensate_u8_a_input(weight, transpose, k, n, alpha));

model.register_variable(name + suffix, compensation);
}

// If requested, linear weights can be packed for the Gemm call.
if (allow_packing && cpu::pack_gemm_weights(compute_type)) {
const std::string suffix = "_packed";

auto& packed_weight = variables[suffix];
if (!packed_weight)
packed_weight = std::make_shared<StorageView>(
ops::Gemm::pack_b_input(weight, transpose, k, n, alpha));

model.register_variable(name + suffix, packed_weight);
} else {
model.register_variable(name, weight_ptr);
}
}


LayerNorm::LayerNorm(const models::Model& model, const std::string& scope)
: _beta(model.get_variable_if_exists(scope + "/beta"))
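Taken together, Dense::register_weight centralizes the device move, the optional u8 shift compensation, and the optional packed GEMM representation for a finalized linear weight. A hedged caller-side sketch; the loop, the accessor names, and the per-weight cache are assumptions for illustration, not code from this PR:

// Illustrative finalization pass over already quantized linear weights.
for (auto& [name, weight] : linear_weights) {   // name e.g. "decoder/projection/weight"
  // Per-weight cache so the compensation and packed buffers are computed at most once.
  std::unordered_map<std::string, std::shared_ptr<ctranslate2::StorageView>> cache;
  ctranslate2::layers::Dense::register_weight(name,
                                              weight,                         // shared_ptr<StorageView>
                                              model,
                                              cache,
                                              model.device(),                 // assumed accessor
                                              model.effective_compute_type(), // assumed accessor
                                              /*allow_packing=*/true);
}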