From 9de7dd664c6504895709ee184d9834c380c8a183 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 25 Jan 2023 12:24:47 +0100 Subject: [PATCH] Directly register finalized model variables --- include/ctranslate2/layers/common.h | 9 + include/ctranslate2/models/model.h | 29 +- include/ctranslate2/models/transformer.h | 2 +- include/ctranslate2/ops/gemm.h | 14 + include/ctranslate2/types.h | 2 + python/tests/test_marian.py | 18 + src/cpu/backend.cc | 4 +- src/cpu/backend.h | 2 +- src/layers/common.cc | 56 ++- src/models/model.cc | 456 +++++++++++------------ src/models/transformer.cc | 45 +-- src/ops/gemm.cc | 79 ++++ src/types.cc | 14 + src/utils.cc | 14 +- 14 files changed, 451 insertions(+), 293 deletions(-) diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index fdb3a08e4..8c7e18f39 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -131,6 +131,15 @@ namespace ctranslate2 { dim_t output_size() const override; void operator()(const StorageView& input, StorageView& output) const; void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr); + + static void + register_weight(const std::string& name, + std::shared_ptr weight, + models::Model& model, + std::unordered_map>& variables, + Device target_device, + ComputeType compute_type, + bool allow_packing = false); private: bool _packed_weight; const StorageView& _weight; diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 14ec7a75b..3dda45674 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -19,6 +19,8 @@ namespace ctranslate2 { class SequenceToSequenceReplica; class SequenceGeneratorReplica; + using VariablesCollection = std::unordered_map>; + // Base class for models. class Model : public std::enable_shared_from_this { public: @@ -114,7 +116,13 @@ namespace ctranslate2 { return static_cast(get_attribute_with_default(name, 0)); } + // Registers a model variable. + void register_variable(std::string name, std::shared_ptr variable); + protected: + // Updates a variable name on load, e.g. for backward compatibility. + virtual void update_variable_name(std::string&) const {} + // Returns true if the variable is quantizable and should respect compute_type. virtual bool is_quantizable(const std::string& variable_name) const; @@ -125,27 +133,18 @@ namespace ctranslate2 { virtual bool is_packable(const std::string& variable_name) const; // Returns true if the variable can be converted to another type. - virtual bool is_convertible(const StorageView& variable, const std::string& name) const; - - // Models can override these methods to execute some transformations if needed - // (e.g. a variable name changed in a newer spec revision). - virtual void register_variable(std::string name, StorageView variable); - virtual void register_variable_alias(std::string alias, const std::string& variable_name); - virtual void remove_variable(const std::string& name); + virtual bool is_convertible(const std::string& name, DataType dtype) const; // Runs some initialization after the model is loaded. 
virtual void initialize(ModelReader&) {} virtual std::unique_ptr clone() const = 0; - private: - void process_linear_weights(); - void set_compute_type(ComputeType type, Device device, int device_index); - void ensure_dtype(const std::string& name, - StorageView& variable, - const DataType target_dtype); - ComputeType infer_compute_type() const; + void register_variable(std::string name, StorageView variable); + void register_variable_alias(std::string alias, const std::string& variable_name); + void remove_variable(const std::string& name); + private: Device _device = Device::CPU; int _device_index = 0; size_t _binary_version = 0; @@ -153,7 +152,7 @@ namespace ctranslate2 { ComputeType _compute_type = ComputeType::DEFAULT; ComputeType _effective_compute_type = ComputeType::DEFAULT; dim_t _preferred_size_multiple = 1; - std::unordered_map> _variable_index; + VariablesCollection _variable_index; }; template<> diff --git a/include/ctranslate2/models/transformer.h b/include/ctranslate2/models/transformer.h index 2ee4cc5b5..019535c36 100644 --- a/include/ctranslate2/models/transformer.h +++ b/include/ctranslate2/models/transformer.h @@ -15,9 +15,9 @@ namespace ctranslate2 { std::unique_ptr as_sequence_to_sequence() const override; protected: + void update_variable_name(std::string& variable_name) const override; bool is_linear_weight(const std::string& variable_name) const override; bool is_packable(const std::string& variable_name) const override; - void register_variable(std::string name, StorageView variable) override; void initialize(ModelReader& model_reader) override; std::unique_ptr clone() const override; diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index ea489c99f..af56a6235 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -26,6 +26,20 @@ namespace ctranslate2 { const StorageView* a_shift_compensation = nullptr, const StorageView* bias = nullptr) const; + // Return the packed representation of b, if implemented by the GEMM backend. + static StorageView pack_b_input(const StorageView& b, + const bool transpose, + const dim_t k, + const dim_t n, + const float alpha); + + // Return the compensation term when s8s8s32 is implemented with u8s8s32. + static StorageView compensate_u8_a_input(const StorageView& b, + const bool transpose, + const dim_t k, + const dim_t n, + const float alpha); + private: float _alpha; float _beta; diff --git a/include/ctranslate2/types.h b/include/ctranslate2/types.h index 906b3b362..0616f5d27 100644 --- a/include/ctranslate2/types.h +++ b/include/ctranslate2/types.h @@ -49,6 +49,8 @@ namespace ctranslate2 { // Gets the weights data type for the given compute type. DataType compute_type_to_data_type(const ComputeType compute_type); + ComputeType data_type_to_compute_type(const DataType quantizable_type, + const DataType float_type); // Gets the default floating point type for the given compute type. 
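The data_type_to_compute_type() declaration above is what lets the loader infer a default compute type directly from the dtypes of the serialized variables. A minimal standalone mirror of the mapping implemented later in src/types.cc (simplified enums, for illustration only):

#include <cassert>

// Simplified stand-ins for ctranslate2::DataType and ctranslate2::ComputeType.
enum class DataType { FLOAT, FLOAT16, INT8, INT16 };
enum class ComputeType { FLOAT, FLOAT16, INT8, INT8_FLOAT16, INT16 };

// The quantizable (weight) dtype picks the base compute type, and the float
// dtype decides between INT8 and INT8_FLOAT16.
static ComputeType data_type_to_compute_type(DataType quantizable_type, DataType float_type) {
  switch (quantizable_type) {
  case DataType::INT8:
    return float_type == DataType::FLOAT16 ? ComputeType::INT8_FLOAT16 : ComputeType::INT8;
  case DataType::INT16:
    return ComputeType::INT16;
  case DataType::FLOAT16:
    return ComputeType::FLOAT16;
  default:
    return ComputeType::FLOAT;
  }
}

int main() {
  assert(data_type_to_compute_type(DataType::INT8, DataType::FLOAT16) == ComputeType::INT8_FLOAT16);
  assert(data_type_to_compute_type(DataType::FLOAT, DataType::FLOAT) == ComputeType::FLOAT);
  return 0;
}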
DataType get_default_float_type(const ComputeType compute_type); diff --git a/python/tests/test_marian.py b/python/tests/test_marian.py index aea6d01fa..d8f41b15d 100644 --- a/python/tests/test_marian.py +++ b/python/tests/test_marian.py @@ -1,5 +1,6 @@ import os +import pytest import test_utils import ctranslate2 @@ -13,3 +14,20 @@ def test_marian_model_conversion(tmpdir): translator = ctranslate2.Translator(output_dir) output = translator.translate_batch([["▁Hello", "▁world", "!"]]) assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"] + + +@pytest.mark.parametrize( + "quantization", [None, "int8", "int16", "float16", "int8_float16"] +) +def test_marian_model_quantization(tmpdir, quantization): + model_dir = os.path.join(test_utils.get_data_dir(), "models", "opus-mt-ende") + converter = ctranslate2.converters.OpusMTConverter(model_dir) + output_dir = str(tmpdir.join("ctranslate2_model")) + converter.convert(output_dir, quantization=quantization) + + compute_types = ["default"] + list(ctranslate2.get_supported_compute_types("cpu")) + + for compute_type in compute_types: + translator = ctranslate2.Translator(output_dir, compute_type=compute_type) + output = translator.translate_batch([["▁Hello", "▁world", "!"]]) + assert output[0].hypotheses[0] == ["▁Hallo", "▁Welt", "!"] diff --git a/src/cpu/backend.cc b/src/cpu/backend.cc index 716a013f9..b53b0b468 100644 --- a/src/cpu/backend.cc +++ b/src/cpu/backend.cc @@ -93,9 +93,9 @@ namespace ctranslate2 { return gemm_s8_backend == cpu::GemmBackend::MKL || gemm_s8_backend == cpu::GemmBackend::DNNL; } - bool should_pack_gemm_weights() { + bool pack_gemm_weights(ComputeType compute_type) { static const bool should_pack = read_bool_from_env("CT2_USE_EXPERIMENTAL_PACKED_GEMM"); - return should_pack; + return should_pack && get_gemm_backend(compute_type) == GemmBackend::MKL; } #ifdef CT2_WITH_RUY diff --git a/src/cpu/backend.h b/src/cpu/backend.h index d5572614a..9c39769ba 100644 --- a/src/cpu/backend.h +++ b/src/cpu/backend.h @@ -25,7 +25,7 @@ namespace ctranslate2 { GemmBackend get_gemm_backend(ComputeType compute_type); bool has_gemm_backend(ComputeType compute_type); bool prefer_u8s8s32_gemm(); - bool should_pack_gemm_weights(); + bool pack_gemm_weights(ComputeType compute_type); #ifdef CT2_WITH_RUY ruy::Context *get_ruy_context(); #endif diff --git a/src/layers/common.cc b/src/layers/common.cc index 9dfb6339e..f0d5805f1 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -254,6 +254,10 @@ namespace ctranslate2 { return model.get_variable(scope + "/weight"); } + static inline bool require_compensation(Device device, DataType dtype) { + return device == Device::CPU && dtype == DataType::INT8 && cpu::prefer_u8s8s32_gemm(); + } + Dense::Dense(const models::Model& model, const std::string& scope, const ops::ActivationType* activation_type) @@ -261,9 +265,7 @@ namespace ctranslate2 { , _weight(get_linear_weight(model, scope, &_packed_weight)) , _bias(model.get_variable_if_exists(scope + "/bias")) , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) - , _u8_shift_compensation((_weight.device() == Device::CPU - && _weight.dtype() == DataType::INT8 - && cpu::prefer_u8s8s32_gemm()) + , _u8_shift_compensation(require_compensation(_weight.device(), _weight.dtype()) ? 
&model.get_variable(scope + "/weight_compensation") : nullptr) , _partial_weight(_weight.device(), _weight.dtype()) @@ -347,6 +349,54 @@ namespace ctranslate2 { } } + void + Dense::register_weight(const std::string& name, + std::shared_ptr weight_ptr, + models::Model& model, + std::unordered_map>& variables, + Device target_device, + ComputeType compute_type, + bool allow_packing) { + auto& weight = *weight_ptr; + + if (weight.device() != target_device) + weight = weight.to(target_device); + + const bool transpose = true; + const dim_t k = weight.dim(1); + const dim_t n = weight.dim(0); + const float alpha = 1; + + // If the target Gemm implementation prefers the u8s8s32 format, we can shift + // the input of linear layers to the u8 domain and add a compensation term. + // This term only depends on the linear weight, so we can compute it once and + // store it as a model variable. + if (require_compensation(weight.device(), weight.dtype())) { + const std::string suffix = "_compensation"; + + auto& compensation = variables[suffix]; + if (!compensation) + compensation = std::make_shared( + ops::Gemm::compensate_u8_a_input(weight, transpose, k, n, alpha)); + + model.register_variable(name + suffix, compensation); + } + + // If requested, linear weights can be packed for the Gemm call. + if (allow_packing && cpu::pack_gemm_weights(compute_type)) { + const std::string suffix = "_packed"; + + auto& packed_weight = variables[suffix]; + if (!packed_weight) + packed_weight = std::make_shared( + ops::Gemm::pack_b_input(weight, transpose, k, n, alpha)); + + model.register_variable(name + suffix, packed_weight); + } else { + model.register_variable(name, weight_ptr); + } + } + LayerNorm::LayerNorm(const models::Model& model, const std::string& scope) : _beta(model.get_variable_if_exists(scope + "/beta")) diff --git a/src/models/model.cc b/src/models/model.cc index 836833c50..051cf7492 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -2,6 +2,7 @@ #include +#include "ctranslate2/layers/common.h" #include "ctranslate2/models/model_factory.h" #include "ctranslate2/ops/ops.h" #include "ctranslate2/utils.h" @@ -10,8 +11,6 @@ # include "cuda/utils.h" #endif -#include "cpu/backend.h" - namespace ctranslate2 { namespace models { @@ -64,7 +63,6 @@ namespace ctranslate2 { return str; } - template static void move_variables_to_device(VariablesCollection& variables, const Device device) { for (auto& pair : variables) { StorageView& variable = *pair.second; @@ -74,7 +72,6 @@ namespace ctranslate2 { } } - template static void move_variables(VariablesCollection& variables, const Device src_device, const int src_device_index, const Device dst_device, const int dst_device_index) { @@ -121,38 +118,6 @@ namespace ctranslate2 { return copy; } - template - static void pack_weight(const StorageView& weight, - const bool transpose, - const dim_t k, - const dim_t n, - const float alpha, - StorageView& packed_weight) { - const T* src = weight.data(); - const dim_t pack_bytes = primitives::gemm_pack_b(src, - transpose, - k, n, - alpha); - - if (pack_bytes == 0) // Packed Gemm is not supported. - return; - - const dim_t pack_size = pack_bytes / sizeof (T); - const dim_t weight_size = weight.size(); - - // We want the packed storage to have the same shape as the original weight - // so that operators can query its shape, but also have enough space to store - // the packed data. 
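This reserve/resize_as trick survives in the new Gemm::pack_b further down: the buffer is given enough capacity for the packed layout while its reported shape stays that of the original weight, much like capacity vs. size on a std::vector. A rough standalone analogue (the pack_size value here is hypothetical):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::size_t weight_size = 512 * 512;     // logical number of elements in the weight
  const std::size_t pack_size = 512 * 512 + 64;  // hypothetical size required by the packed layout

  std::vector<int8_t> packed;
  packed.reserve(std::max(weight_size, pack_size)); // enough room for the packed data...
  packed.resize(weight_size);                       // ...but the visible size matches the weight

  std::cout << "size=" << packed.size() << " capacity=" << packed.capacity() << std::endl;
  return 0;
}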
- packed_weight.reserve(std::max(weight_size, pack_size)); - packed_weight.resize_as(weight); - - primitives::gemm_pack_b(src, - transpose, - k, n, - alpha, - packed_weight.data()); - } - std::unique_ptr Model::as_sequence_to_sequence() const { throw std::runtime_error("This model cannot be used as a sequence-to-sequence model"); @@ -179,38 +144,6 @@ namespace ctranslate2 { _device_index = index; } - void Model::set_compute_type(ComputeType type, Device device, int device_index) { - if (_device != Device::CPU) - throw std::runtime_error("set_compute_type expects the variables to be on CPU"); - - _compute_type = type; - _effective_compute_type = resolve_compute_type(type, - infer_compute_type(), - device, - device_index); - _preferred_size_multiple = get_preferred_size_multiple(_effective_compute_type, - device, - device_index); - - const DataType target_dtype = compute_type_to_data_type(_effective_compute_type); - const DataType float_dtype = get_default_float_type(_effective_compute_type); - - const auto variable_index = _variable_index; - for (auto& variable_pair : variable_index) { - const auto& name = variable_pair.first; - auto& variable = *variable_pair.second; - - // Convert "weight" variables to the expected compute type. - // Other float variables (e.g. biases) may be converted from or to float16. - if (is_quantizable(name)) - ensure_dtype(name, variable, target_dtype); - else if (is_convertible(variable, name) - && is_float_type(variable.dtype()) - && variable.dtype() != float_dtype) - variable = variable.to(float_dtype); - } - } - const StorageView* Model::get_variable_if_exists(const std::string& name) const { auto it = _variable_index.find(name); if (it == _variable_index.end()) @@ -248,8 +181,12 @@ namespace ctranslate2 { return get_attribute_with_default(name, static_cast(default_value)); } + void Model::register_variable(std::string name, std::shared_ptr variable) { + _variable_index.emplace(std::move(name), std::move(variable)); + } + void Model::register_variable(std::string name, StorageView variable) { - _variable_index.emplace(std::move(name), std::make_shared(std::move(variable))); + register_variable(std::move(name), std::make_shared(std::move(variable))); } void Model::register_variable_alias(std::string alias, const std::string& variable_name) { @@ -275,31 +212,27 @@ namespace ctranslate2 { return is_linear_weight(variable_name); } - bool Model::is_convertible(const StorageView& variable, const std::string& name) const { - return !variable.is_scalar() && name.find("_scale") == std::string::npos; + bool Model::is_convertible(const std::string& name, DataType dtype) const { + return is_float_type(dtype) && name.find("_scale") == std::string::npos; } - void Model::ensure_dtype(const std::string& name, - StorageView& variable, - const DataType target_dtype) { + static void + convert_weight(const std::string& name, + StorageView& variable, + StorageView& scale, + const DataType target_dtype, + const bool round_before_cast) { const bool is_int8 = variable.dtype() == DataType::INT8; const bool is_int16 = variable.dtype() == DataType::INT16; const bool is_float = variable.dtype() == DataType::FLOAT; const bool is_float16 = variable.dtype() == DataType::FLOAT16; - const std::string scale_name = name + "_scale"; - const StorageView* saved_scale = nullptr; - if (is_int8 || is_int16) { - // Check that the quantization scale of the variable exists. 
- saved_scale = get_variable_if_exists(scale_name); - if (!saved_scale) { - if (is_int16) { - // Backward compatibility with int16 models without a saved scale. - register_variable(scale_name, StorageView(ops::Quantize::global_int16_scale)); - saved_scale = get_variable_if_exists(scale_name); - } else { - throw std::runtime_error("variable " + scale_name + " not found"); - } + if (!scale) { + if (is_int16) { + // Backward compatibility with int16 models without a saved scale. + scale = StorageView(ops::Quantize::global_int16_scale); + } else if (is_int8) { + throw std::runtime_error("Missing quantization scale for int8 variable " + name); } } @@ -309,7 +242,7 @@ namespace ctranslate2 { // Use the same quantization logic as in model_spec.py. const ops::Quantize quantize_op(/*int16_scale_type=*/ops::Quantize::ScaleType::PER_LAYER, /*shift_to_uint8=*/false, - /*round_before_cast=*/round_before_cast_in_quantization()); + /*round_before_cast=*/round_before_cast); const ops::Dequantize dequantize_op{}; StorageView target_variable(target_dtype); @@ -321,8 +254,8 @@ namespace ctranslate2 { } else { // Dequantize int8 or int16 back to float32. StorageView dequantized; - dequantize_op(variable, *saved_scale, dequantized); - remove_variable(scale_name); // The scale is no longer needed. + dequantize_op(variable, scale, dequantized); + scale.clear(); // The scale is no longer needed. if (target_dtype == DataType::FLOAT16) { target_variable = dequantized.to_float16(); } else { @@ -332,113 +265,22 @@ namespace ctranslate2 { } else if (is_float || is_float16) { // Quantize float32 to int8 or int16. - StorageView scale; if (is_float16) { quantize_op(variable.to_float(), target_variable, scale); } else { quantize_op(variable, target_variable, scale); } - register_variable(scale_name, std::move(scale)); } else { // Convert int8 -> float32 -> int16 or int16 -> float32 -> int8. StorageView tmp_variable; - StorageView new_scale; - dequantize_op(variable, *saved_scale, tmp_variable); - quantize_op(tmp_variable, target_variable, new_scale); - remove_variable(scale_name); - register_variable(scale_name, std::move(new_scale)); + dequantize_op(variable, scale, tmp_variable); + quantize_op(tmp_variable, target_variable, scale); } variable = std::move(target_variable); } - ComputeType Model::infer_compute_type() const { - DataType weight_type = DataType::FLOAT; - DataType other_type = DataType::FLOAT; - - for (const auto& variable_pair : _variable_index) { - const std::string& name = variable_pair.first; - const StorageView& variable = *variable_pair.second; - if (is_quantizable(name)) { - weight_type = variable.dtype(); - } else if (is_convertible(variable, name)) { - other_type = variable.dtype(); - } - } - - switch (weight_type) { - case DataType::INT8: - return other_type == DataType::FLOAT16 ? ComputeType::INT8_FLOAT16 : ComputeType::INT8; - case DataType::INT16: - return ComputeType::INT16; - case DataType::FLOAT16: - return ComputeType::FLOAT16; - default: - return ComputeType::FLOAT; - } - } - - // This method runs some precomputations on linear weights when possible. - void Model::process_linear_weights() { - if (_device != Device::CPU) - return; // There is currently no processing for non CPU device. 
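The new convert_weight() above moves weights between int8/int16 and float32/float16 by pairing the Quantize and Dequantize ops. As a simplified illustration of the symmetric scheme with a single per-tensor scale (the real op also handles the global int16 scale and the round_before_cast mode shown above):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<float> weight = {0.02f, -0.75f, 0.33f, -0.10f};

  // Quantize: the scale maps the largest magnitude onto the int8 range.
  float amax = 0;
  for (float v : weight)
    amax = std::max(amax, std::abs(v));
  const float scale = 127.f / amax;

  std::vector<int8_t> quantized(weight.size());
  for (std::size_t i = 0; i < weight.size(); ++i)
    quantized[i] = static_cast<int8_t>(std::nearbyint(weight[i] * scale));

  // Dequantize: divide by the same scale to recover an approximation.
  for (std::size_t i = 0; i < weight.size(); ++i)
    std::cout << weight[i] << " ~ " << quantized[i] / scale << std::endl;

  return 0;
}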
- - const bool should_pack_weights = cpu::should_pack_gemm_weights(); - const bool transpose = true; - const float alpha = 1; - - const auto variable_index = _variable_index; - for (const auto& pair : variable_index) { - const std::string& name = pair.first; - if (!is_linear_weight(name)) - continue; - - const StorageView& weight = *pair.second; - const DataType dtype = weight.dtype(); - const dim_t k = weight.dim(1); - const dim_t n = weight.dim(0); - - // If the target Gemm implementation prefers the u8s8s32 format, we can shift - // the input of linear layers to the u8 domain and add a compensation term. - // This term only depends on the linear weight, so we can compute it once and - // store it as a model variable. - if (dtype == DataType::INT8 && cpu::prefer_u8s8s32_gemm()) { - StorageView compensation({n}, DataType::INT32); - primitives::compute_u8_compensation(weight.data(), - transpose, - k, n, - alpha, - compensation.data()); - register_variable(name + "_compensation", std::move(compensation)); - } - - // If requested, linear weights can be packed for the Gemm call. - if (should_pack_weights && is_packable(name)) { - StorageView packed_weight(dtype); - - switch (dtype) { - case DataType::FLOAT: - pack_weight(weight, transpose, k, n, alpha, packed_weight); - break; - case DataType::INT16: - pack_weight(weight, transpose, k, n, alpha, packed_weight); - break; - case DataType::INT8: - pack_weight(weight, transpose, k, n, alpha, packed_weight); - break; - default: - break; - } - - if (!packed_weight.empty()) { - register_variable(name + "_packed", std::move(packed_weight)); - remove_variable(name); // The original weight is no longer needed. - } - } - } - } - static DataType get_dtype_from_item_size(uint8_t item_size) { // This is the old (and flawed) logic of resolving the dtype of saved variables. switch (item_size) { @@ -467,6 +309,97 @@ namespace ctranslate2 { + "(Forward compatibility is not guaranteed.)"); } + // See the model serialization in python/ctranslate2/specs/model_spec.py. + + struct SerializedVariable { + std::string name; + Shape shape; + DataType dtype; + size_t offset; + size_t num_bytes; + + SerializedVariable(std::istream& in, size_t binary_version) { + name = consume(in); + + const size_t rank = consume(in); + const auto* dimensions = consume(in, rank); + shape.assign(dimensions, dimensions + rank); + delete [] dimensions; + + if (binary_version >= 4) { + const auto type_id = consume(in); + dtype = static_cast(type_id); + num_bytes = consume(in); + } else { + const auto item_size = consume(in); + dtype = get_dtype_from_item_size(item_size); + num_bytes = consume(in) * item_size; + } + + offset = in.tellg(); + in.seekg(offset + num_bytes); + } + + // Actually load the variable in memory. 
+ StorageView load(std::istream& in) const { + StorageView variable(shape, dtype); + + const auto previous_offset = in.tellg(); + in.seekg(offset); + consume(in, num_bytes, static_cast(variable.buffer())); + in.seekg(previous_offset); + + return variable; + } + }; + + struct SerializedModel { + size_t spec_revision; + std::string spec_name; + std::vector variables; + + SerializedModel(std::istream& in, size_t binary_version) { + if (binary_version >= 2) { + spec_name = consume(in); + spec_revision = consume(in); + } else { + spec_revision = 1; + } + + const auto num_variables = consume(in); + variables.reserve(num_variables); + + for (uint32_t i = 0; i < num_variables; ++i) { + variables.emplace_back(in, binary_version); + } + + if (binary_version >= 3) { + const auto num_aliases = consume(in); + variables.reserve(num_variables + num_aliases); + + for (uint32_t i = 0; i < num_aliases; ++i) { + auto alias_name = consume(in); + auto variable_name = consume(in); + + auto* variable = get_variable(variable_name); + if (variable) { + variables.emplace_back(*variable); + variables.back().name = std::move(alias_name); + } + } + } + } + + SerializedVariable* get_variable(const std::string& name) { + for (SerializedVariable& variable : variables) { + if (variable.name == name) + return &variable; + } + + return nullptr; + } + }; + std::shared_ptr Model::load(const std::string& path, Device device, int device_index, @@ -485,35 +418,24 @@ namespace ctranslate2 { std::call_once(log_once, log_system_config); } - { - // Check that the device and device index are valid. - ScopedDeviceSetter(device, device_index); - } + const ScopedDeviceSetter scoped_device_setter(device, device_index); std::unique_ptr model_file_ptr = model_reader.get_required_file(binary_file, /*binary=*/true); std::istream& model_file = *model_file_ptr; - // See the model serialization in python/ctranslate2/specs/model_spec.py. - - // Check the binary version and spec revision. const size_t binary_version = consume(model_file); check_version(binary_version, current_binary_version, "binary version"); - std::string spec; - size_t spec_revision; - if (binary_version >= 2) { - spec = consume(model_file); - spec_revision = consume(model_file); - } else { - spec_revision = 1; - } + SerializedModel serialized_model(model_file, binary_version); - auto model = create_model(spec); + auto model = create_model(serialized_model.spec_name); model->_binary_version = binary_version; - model->_spec_revision = spec_revision; + model->_spec_revision = serialized_model.spec_revision; + model->_device = device; + model->_device_index = device_index; - check_version(spec_revision, model->current_spec_revision(), "revision"); + check_version(serialized_model.spec_revision, model->current_spec_revision(), "revision"); { std::unique_ptr config_file_ptr = model_reader.get_file(config_file); @@ -521,54 +443,108 @@ namespace ctranslate2 { model->config = nlohmann::json::parse(*config_file_ptr); } - // Load the variables. - const auto num_variables = consume(model_file); - model->_variable_index.reserve(num_variables); - for (uint32_t i = 0; i < num_variables; ++i) { - auto name = consume(model_file); - const size_t rank = consume(model_file); - const auto* dimensions = consume(model_file, rank); - Shape shape(dimensions, dimensions + rank); - delete [] dimensions; + // Multiple variables can point to the same buffer in the model. + // We will make sure that these variables remain shared. 
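As a standalone illustration of that sharing: serialized entries (including aliases) that record the same file offset end up behind a single shared_ptr buffer, so a dtype conversion or device move is done once for all names (mock types and example names, not the real StorageView):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // name -> file offset, as recorded while scanning the serialized model.
  const std::vector<std::pair<std::string, std::size_t>> serialized = {
    {"decoder/projection/weight", 1024},
    {"decoder/embeddings/weight", 1024},  // alias: same underlying data
    {"decoder/layer_norm/gamma", 4096},
  };

  std::unordered_map<std::size_t, std::shared_ptr<std::vector<float>>> buffer_at_offset;
  std::unordered_map<std::string, std::shared_ptr<std::vector<float>>> variables;

  for (const auto& entry : serialized) {
    auto& buffer = buffer_at_offset[entry.second];
    if (!buffer)
      buffer = std::make_shared<std::vector<float>>(/*count=*/8, /*value=*/0.f);  // load once
    variables[entry.first] = buffer;  // every alias shares the same buffer
  }

  std::cout << std::boolalpha
            << (variables["decoder/projection/weight"] == variables["decoder/embeddings/weight"])
            << std::endl;  // true: both names see the same storage
  return 0;
}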
+ std::unordered_map> variables_at_offset; + variables_at_offset.reserve(serialized_model.variables.size()); - DataType dtype; - dim_t num_bytes = 0; - if (binary_version >= 4) { - const auto type_id = consume(model_file); - dtype = static_cast(type_id); - num_bytes = consume(model_file); - } else { - const auto item_size = consume(model_file); - dtype = get_dtype_from_item_size(item_size); - num_bytes = consume(model_file) * item_size; + DataType weight_type = DataType::FLOAT; + DataType float_type = DataType::FLOAT; + + for (auto& variable : serialized_model.variables) { + // Models may rename some variables for backward compatibility. + model->update_variable_name(variable.name); + + // Quantization scales will be processed alongside their corresponding weight. + if (ends_with(variable.name, "_scale")) + continue; + + // Scalars can be registered immediately. + if (variable.shape.empty()) { + model->register_variable(variable.name, variable.load(model_file)); + continue; } - StorageView variable(std::move(shape), dtype); - consume(model_file, num_bytes, static_cast(variable.buffer())); - model->register_variable(std::move(name), std::move(variable)); + // Gather some information about the variables type to resolve the default compute type. + if (model->is_quantizable(variable.name)) + weight_type = variable.dtype; + else if (model->is_convertible(variable.name, variable.dtype)) + float_type = variable.dtype; + + variables_at_offset[variable.offset].push_back(&variable); } - // Maybe quantize/dequantize/convert the variables to match the requested compute type. - model->set_compute_type(compute_type, device, device_index); - - // Move variables to the target device. - model->set_device(device, device_index); - - // Register variable aliases. - if (binary_version >= 3) { - const auto num_aliases = consume(model_file); - for (uint32_t i = 0; i < num_aliases; ++i) { - const auto alias = consume(model_file); - const auto variable_name = consume(model_file); - model->register_variable_alias(alias, variable_name); - // Also alias the quantization scale that could be associated to variable_name. - model->register_variable_alias(alias + "_scale", variable_name + "_scale"); + ComputeType default_compute_type = data_type_to_compute_type(weight_type, float_type); + ComputeType effective_compute_type = resolve_compute_type(compute_type, + default_compute_type, + device, + device_index); + + // Update the target dtypes based on the effective compute type. + weight_type = compute_type_to_data_type(effective_compute_type); + float_type = get_default_float_type(effective_compute_type); + + model->_compute_type = compute_type; + model->_effective_compute_type = effective_compute_type; + model->_preferred_size_multiple = get_preferred_size_multiple(effective_compute_type, + device, + device_index); + + const bool quantization_round_mode = model->round_before_cast_in_quantization(); + + for (const auto& pair : variables_at_offset) { + VariablesCollection variables; + + for (const auto* serialized_variable : pair.second) { + const auto& name = serialized_variable->name; + + auto& weight = variables[""]; + if (!weight) + weight = std::make_shared(serialized_variable->load(model_file)); + + if (model->is_quantizable(name)) { + auto& scale = variables["_scale"]; + + if (!scale) { + const auto* serialized_scale = serialized_model.get_variable(name + "_scale"); + scale = (serialized_scale + ? 
std::make_shared(serialized_scale->load(model_file)) + : std::make_shared()); + } + + convert_weight(name, *weight, *scale, weight_type, quantization_round_mode); + + if (!scale->empty()) { + if (scale->device() != device && !scale->is_scalar()) + *scale = scale->to(device); + model->register_variable(name + "_scale", scale); + } + + if (model->is_linear_weight(name)) { + layers::Dense::register_weight(name, + weight, + *model, + variables, + device, + effective_compute_type, + model->is_packable(name)); + } else { + if (weight->device() != device) + *weight = weight->to(device); + model->register_variable(name, weight); + } + + } else { + if (model->is_convertible(name, weight->dtype()) && weight->dtype() != float_type) + *weight = weight->to(float_type); + if (weight->device() != device) + *weight = weight->to(device); + model->register_variable(name, weight); + } } } // Run additional model initialization. - const ScopedDeviceSetter scoped_device_setter(device, device_index); - model->process_linear_weights(); model->initialize(model_reader); return model; } diff --git a/src/models/transformer.cc b/src/models/transformer.cc index 46cf360c4..356391575 100644 --- a/src/models/transformer.cc +++ b/src/models/transformer.cc @@ -14,25 +14,6 @@ namespace ctranslate2 { return true; } - static std::string map_v1_variable_name(std::string name) { - // V1 variable names were simply the names defined by OpenNMT-tf. - replace(name, "transformer/", ""); - replace(name, ":0", ""); - replace(name, "w_embs", "embeddings/weight"); - replace(name, "kernel", "weight"); - replace(name, "LayerNorm", "layer_norm"); - replace(name, "dense", "projection"); - replace(name, "conv1d_", "linear_"); - replace(name, "conv1d", "linear_0"); - if (name.find("encoder") != std::string::npos) { - replace(name, "multi_head", "self_attention"); - } else { - replace(name, "masked_multi_head", "self_attention"); - replace(name, "multi_head", "attention"); - } - return name; - } - // Empty spec name, TransformerBase, and TransformerBig are there for backward compatibility. static auto register_empty = register_model("", /*num_heads=*/8); @@ -48,6 +29,26 @@ namespace ctranslate2 { return 6; } + void TransformerModel::update_variable_name(std::string& variable_name) const { + if (spec_revision() == 1) { + // In the first specification, variable names were the names defined by OpenNMT-tf V1. + replace(variable_name, "transformer/", ""); + replace(variable_name, ":0", ""); + replace(variable_name, "w_embs", "embeddings/weight"); + replace(variable_name, "kernel", "weight"); + replace(variable_name, "LayerNorm", "layer_norm"); + replace(variable_name, "dense", "projection"); + replace(variable_name, "conv1d_", "linear_"); + replace(variable_name, "conv1d", "linear_0"); + if (variable_name.find("encoder") != std::string::npos) { + replace(variable_name, "multi_head", "self_attention"); + } else { + replace(variable_name, "masked_multi_head", "self_attention"); + replace(variable_name, "multi_head", "attention"); + } + } + } + bool TransformerModel::is_linear_weight(const std::string& variable_name) const { // Linear weights are all variables that are quantizable and not under the "embeddings" scope. 
return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos; @@ -59,12 +60,6 @@ namespace ctranslate2 { && (!get_vocabulary_map() || variable_name.find("projection") == std::string::npos)); } - void TransformerModel::register_variable(std::string name, StorageView variable) { - if (spec_revision() == 1) - name = map_v1_variable_name(std::move(name)); - SequenceToSequenceModel::register_variable(std::move(name), std::move(variable)); - } - void TransformerModel::initialize(ModelReader& model_reader) { SequenceToSequenceModel::initialize(model_reader); diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index 6882b8dd6..7a1219ad3 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -102,5 +102,84 @@ namespace ctranslate2 { a_shift_compensation ? a_shift_compensation->data() : nullptr); } + template + static void pack_b(const StorageView& b, + const bool transpose, + const dim_t k, + const dim_t n, + const float alpha, + StorageView& packed) { + const T* src = b.data(); + const dim_t pack_bytes = primitives::gemm_pack_b(src, + transpose, + k, n, + alpha); + + if (pack_bytes == 0) // Packed Gemm is not supported. + throw std::runtime_error("Packed GEMM APIs are not supported by this GEMM backend"); + + const dim_t pack_size = pack_bytes / sizeof (T); + const dim_t b_size = b.size(); + + // We want the packed storage to have the same shape as the original weight + // so that operators can query its shape, but also have enough space to store + // the packed data. + packed.reserve(std::max(b_size, pack_size)); + packed.resize_as(b); + + primitives::gemm_pack_b(src, + transpose, + k, n, + alpha, + packed.data()); + } + + StorageView Gemm::pack_b_input(const StorageView& b, + const bool transpose, + const dim_t k, + const dim_t n, + const float alpha) { + if (b.device() != Device::CPU) + throw std::invalid_argument("Packed GEMM APIs are only defined on CPU"); + + DataType dtype = b.dtype(); + StorageView packed(dtype); + + switch (dtype) { + case DataType::FLOAT: + pack_b(b, transpose, k, n, alpha, packed); + break; + case DataType::INT16: + pack_b(b, transpose, k, n, alpha, packed); + break; + case DataType::INT8: + pack_b(b, transpose, k, n, alpha, packed); + break; + default: + throw std::invalid_argument("Cannot pack GEMM input of type " + dtype_name(dtype)); + break; + } + + return packed; + } + + StorageView Gemm::compensate_u8_a_input(const StorageView& b, + const bool transpose, + const dim_t k, + const dim_t n, + const float alpha) { + if (b.device() != Device::CPU && b.dtype() != DataType::INT8) + throw std::invalid_argument("Unsigned input compensation is only defined for " + "INT8 GEMM on CPU"); + + StorageView compensation({n}, DataType::INT32); + primitives::compute_u8_compensation(b.data(), + transpose, + k, n, + alpha, + compensation.data()); + return compensation; + } + } } diff --git a/src/types.cc b/src/types.cc index 60e40e01b..468644a5a 100644 --- a/src/types.cc +++ b/src/types.cc @@ -229,6 +229,20 @@ namespace ctranslate2 { } } + ComputeType data_type_to_compute_type(const DataType quantizable_type, + const DataType float_type) { + switch (quantizable_type) { + case DataType::INT8: + return float_type == DataType::FLOAT16 ? 
ComputeType::INT8_FLOAT16 : ComputeType::INT8; + case DataType::INT16: + return ComputeType::INT16; + case DataType::FLOAT16: + return ComputeType::FLOAT16; + default: + return ComputeType::FLOAT; + } + } + DataType get_default_float_type(const ComputeType compute_type) { switch (compute_type) { case ComputeType::FLOAT: diff --git a/src/utils.cc b/src/utils.cc index 80995ef36..30e14e981 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -40,14 +40,16 @@ namespace ctranslate2 { #endif spdlog::info(" - Selected ISA: {}", cpu::isa_to_str(cpu::get_cpu_isa())); spdlog::info(" - Use Intel MKL: {}", cpu::mayiuse_mkl()); - spdlog::info(" - SGEMM backend: {}", - cpu::gemm_backend_to_str(cpu::get_gemm_backend(ComputeType::FLOAT))); - spdlog::info(" - GEMM_S16 backend: {}", - cpu::gemm_backend_to_str(cpu::get_gemm_backend(ComputeType::INT16))); - spdlog::info(" - GEMM_S8 backend: {} (u8s8 preferred: {})", + spdlog::info(" - SGEMM backend: {} (packed: {})", + cpu::gemm_backend_to_str(cpu::get_gemm_backend(ComputeType::FLOAT)), + cpu::pack_gemm_weights(ComputeType::FLOAT)); + spdlog::info(" - GEMM_S16 backend: {} (packed: {})", + cpu::gemm_backend_to_str(cpu::get_gemm_backend(ComputeType::INT16)), + cpu::pack_gemm_weights(ComputeType::INT16)); + spdlog::info(" - GEMM_S8 backend: {} (packed: {}, u8s8 preferred: {})", cpu::gemm_backend_to_str(cpu::get_gemm_backend(ComputeType::INT8)), + cpu::pack_gemm_weights(ComputeType::INT8), cpu::prefer_u8s8s32_gemm()); - spdlog::info(" - Use packed GEMM: {}", cpu::should_pack_gemm_weights()); #ifdef CT2_WITH_CUDA for (int i = 0; i < cuda::get_gpu_count(); ++i) {
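For reference, the packing reported in these log lines stays opt-in: pack_gemm_weights() only returns true when CT2_USE_EXPERIMENTAL_PACKED_GEMM is set and the compute type maps to the MKL GEMM backend. A rough standalone sketch of that gate (environment parsing and backend lookup simplified):

#include <cstdlib>
#include <cstring>
#include <iostream>

enum class GemmBackend { MKL, DNNL, RUY, NONE };

// Stand-in for cpu::get_gemm_backend(compute_type); the real lookup depends on the build.
static GemmBackend get_gemm_backend_for_int8() {
  return GemmBackend::MKL;
}

// Simplified boolean environment flag (the real helper accepts more spellings).
static bool read_bool_from_env(const char* name) {
  const char* value = std::getenv(name);
  return value && std::strcmp(value, "") != 0 && std::strcmp(value, "0") != 0;
}

int main() {
  const bool pack = read_bool_from_env("CT2_USE_EXPERIMENTAL_PACKED_GEMM")
                    && get_gemm_backend_for_int8() == GemmBackend::MKL;
  std::cout << "pack int8 weights: " << std::boolalpha << pack << std::endl;
  return 0;
}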
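The "u8s8 preferred" case logged above is also worth a worked example: when only u8·s8→s32 GEMM is available, the signed activations are shifted by +128 and a compensation term that depends only on the weight (conceptually -128 times each column sum of B) restores the signed result, which is what compensate_u8_a_input precomputes. A tiny self-contained check of that identity (illustrative row-major layout, not the actual kernel):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A is 1x3 (signed activations), B is 3x2 (signed weights), C = A * B is 1x2.
  const std::vector<int8_t> a = {-5, 7, 100};
  const std::vector<int8_t> b = { 1, -2,
                                  3,  4,
                                 -6,  8};
  const int k = 3, n = 2;

  // Precomputed compensation, one term per output column: -128 * sum_k B[k][j].
  std::vector<int32_t> compensation(n, 0);
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < k; ++i)
      compensation[j] -= 128 * b[i * n + j];

  for (int j = 0; j < n; ++j) {
    int32_t s8s8 = 0;  // reference: signed * signed
    int32_t u8s8 = 0;  // emulated: (signed + 128) * signed
    for (int i = 0; i < k; ++i) {
      s8s8 += a[i] * b[i * n + j];
      u8s8 += (static_cast<int32_t>(a[i]) + 128) * b[i * n + j];
    }
    std::cout << s8s8 << " == " << u8s8 + compensation[j] << std::endl;
  }
  return 0;
}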