diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 021610556..499a29005 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -135,6 +135,9 @@ namespace ctranslate2 { // Returns true if the variable can be converted to another type. virtual bool is_convertible(const StorageView& variable, const std::string& name) const; + // Returns true if the variable should be kept in float32 precision. + virtual bool keep_in_float32(const std::string& variable_name) const; + // Models can override these methods to execute some transformations if needed // (e.g. a variable name changed in a newer spec revision). virtual void register_variable(std::string name, StorageView variable); diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 0cde51ab3..cea6be3d9 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -1002,6 +1002,8 @@ def set_ffn(self, spec, module): self.set_linear(spec.linear_0, module.DenseReluDense.wi) self.set_linear(spec.linear_1, module.DenseReluDense.wo) + spec.linear_1.keep_in_float32 = True + self.set_layer_norm(spec.layer_norm, module.layer_norm) def set_self_attention(self, spec, module): diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index 14536a653..497634c89 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -35,6 +35,7 @@ def __init__(self): self.weight = None self.weight_scale = model_spec.OPTIONAL self.bias = model_spec.OPTIONAL + self.keep_in_float32 = False def has_bias(self): return not isinstance(self.bias, str) diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 79284d743..77ec49ed4 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -183,7 +183,11 @@ def _quantize(spec, name, value): is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16") - if is_quantizable: + if hasattr(spec, "keep_in_float32") and spec.keep_in_float32.numpy(): + if is_convertible: + value = value.to("float32") + + elif is_quantizable: if quantization == "int16": value = value.to("float32").numpy() # Represent the value with 10 bits so the multiplication is 20 bits diff --git a/src/layers/common.cc b/src/layers/common.cc index 162012e9a..7fa2c1937 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -353,6 +353,10 @@ namespace ctranslate2 { /*trans_b=*/true, output, bias); + } else if (input.dtype() != weight->dtype()) { + StorageView tmp_output(weight->dtype(), weight->device()); + _gemm_op(input.to(weight->dtype()), *weight, tmp_output, nullptr, bias); + output = tmp_output.to(output.dtype()); } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/models/model.cc b/src/models/model.cc index 0672494ff..36ec5f603 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -174,6 +174,9 @@ namespace ctranslate2 { const auto& name = variable_pair.first; auto& variable = *variable_pair.second; + if (keep_in_float32(name)) + continue; + // Convert "weight" variables to the expected compute type. // Other float variables (e.g. biases) may be converted to another float type. if (is_quantizable(name)) @@ -253,6 +256,15 @@ namespace ctranslate2 { return !variable.is_scalar() && name.find("_scale") == std::string::npos; } + bool Model::keep_in_float32(const std::string& variable_name) const { + const size_t pos = variable_name.rfind('/'); + if (pos == std::string::npos) + return false; + + const std::string scope = variable_name.substr(0, pos); + return get_flag_with_default(scope + "/keep_in_float32", false); + } + void Model::ensure_dtype(const std::string& name, StorageView& variable, const DataType target_dtype) { @@ -327,6 +339,8 @@ namespace ctranslate2 { for (const auto& variable_pair : _variable_index) { const std::string& name = variable_pair.first; const StorageView& variable = *variable_pair.second; + if (keep_in_float32(name)) + continue; if (is_quantizable(name)) { weight_type = variable.dtype(); } else if (is_convertible(variable, name)) {