Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep FFN output layer in float32 for T5 models #1239

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Keep FFN output layer in float32 for T5 models
guillaumekln committed May 23, 2023
commit 4bfc6b334978c4fd4026ca5f17a54dbfb0c3dfe7
3 changes: 3 additions & 0 deletions include/ctranslate2/models/model.h
Original file line number Diff line number Diff line change
@@ -133,6 +133,9 @@ namespace ctranslate2 {
// Returns true if the variable can be converted to another type.
virtual bool is_convertible(const StorageView& variable, const std::string& name) const;

// Returns true if the variable should be kept in float32 precision.
virtual bool keep_in_float32(const std::string& variable_name) const;

// Models can override these methods to execute some transformations if needed
// (e.g. a variable name changed in a newer spec revision).
virtual void register_variable(std::string name, StorageView variable);
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
@@ -955,6 +955,8 @@ def set_ffn(self, spec, module):
self.set_linear(spec.linear_0, module.DenseReluDense.wi)

self.set_linear(spec.linear_1, module.DenseReluDense.wo)
spec.linear_1.keep_in_float32 = True

self.set_layer_norm(spec.layer_norm, module.layer_norm)

def set_self_attention(self, spec, module):
1 change: 1 addition & 0 deletions python/ctranslate2/specs/common_spec.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ def __init__(self):
self.weight = None
self.weight_scale = model_spec.OPTIONAL
self.bias = model_spec.OPTIONAL
self.keep_in_float32 = False

def has_bias(self):
    # A real bias was assigned only if it is an ndarray; the OPTIONAL
    # sentinel (the default) does not count as a bias.
    bias_value = self.bias
    return isinstance(bias_value, np.ndarray)
6 changes: 6 additions & 0 deletions python/ctranslate2/specs/model_spec.py
Original file line number Diff line number Diff line change
@@ -166,6 +166,12 @@ def _quantize(spec, name, value):
return

key = _split_scope(name)[-1]

if getattr(spec, "keep_in_float32", False):
if value.dtype == np.float16:
setattr(spec, key, value.astype(np.float32))
return

scale = None
is_quantizable = hasattr(spec, "%s_scale" % key)

3 changes: 3 additions & 0 deletions src/layers/common.cc
Original file line number Diff line number Diff line change
@@ -353,6 +353,9 @@ namespace ctranslate2 {
/*trans_b=*/true,
output,
bias);
} else if (input.dtype() != weight->dtype()) {
_gemm_op(input.to(weight->dtype()), *weight, output, nullptr, bias);
output = output.to(input.dtype());
} else {
_gemm_op(input, *weight, output, nullptr, bias);
}
14 changes: 14 additions & 0 deletions src/models/model.cc
Original file line number Diff line number Diff line change
@@ -170,6 +170,9 @@ namespace ctranslate2 {
const auto& name = variable_pair.first;
auto& variable = *variable_pair.second;

if (keep_in_float32(name))
continue;

// Convert "weight" variables to the expected compute type.
// Other float variables (e.g. biases) may be converted from or to float16.
if (is_quantizable(name))
@@ -249,6 +252,15 @@ namespace ctranslate2 {
return !variable.is_scalar() && name.find("_scale") == std::string::npos;
}

// Returns true when the variable's enclosing scope carries a
// "keep_in_float32" flag, i.e. a sibling entry named
// "<scope>/keep_in_float32" that is set. Variables with no scope
// separator ('/') are never pinned to float32.
bool Model::keep_in_float32(const std::string& variable_name) const {
  const auto separator = variable_name.rfind('/');
  if (separator == std::string::npos)
    return false;
  return get_flag_with_default(
      variable_name.substr(0, separator) + "/keep_in_float32", false);
}

void Model::ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype) {
@@ -330,6 +342,8 @@ namespace ctranslate2 {
for (const auto& variable_pair : _variable_index) {
const std::string& name = variable_pair.first;
const StorageView& variable = *variable_pair.second;
if (keep_in_float32(name))
continue;
if (is_quantizable(name)) {
weight_type = variable.dtype();
} else if (is_convertible(variable, name)) {