diff --git a/.gitignore b/.gitignore
index 9c6801c43..bd54a91c1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 *.pyc
 /.vs
+.vscode
 /build
 CMake*.json
diff --git a/include/ctranslate2/layers/attention_layer.h b/include/ctranslate2/layers/attention_layer.h
index e55ecc5de..941a04913 100644
--- a/include/ctranslate2/layers/attention_layer.h
+++ b/include/ctranslate2/layers/attention_layer.h
@@ -52,6 +52,7 @@ namespace ctranslate2 {
                          const bool multi_query = false);

     protected:
+      bool _is_low_rank;
       const bool _tensor_parallel;
       const dim_t _num_heads;
       const bool _self_attention;
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 137b926d3..fcd7e8114 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -135,7 +135,9 @@ namespace ctranslate2 {
       void select_weights(const StorageView* index, const StorageView* extra_bias = nullptr);
     private:
       bool _packed_weight;
+      bool _is_low_rank;
       const StorageView& _weight;
+      const StorageView* _weight2;
       const StorageView* _bias;
       const StorageView* _qscale;
       const StorageView* _qzero;
@@ -148,6 +150,7 @@ namespace ctranslate2 {
       const models::QUANTIZATION_TYPE _quant_method;
       const bool _quantized_gemm;
       const ops::Gemm _gemm_op;
+      const ops::Gemm _gemm_op_low_rank;
       const ops::Quantize _quantize_op;
       const ops::Dequantize _dequantize_op;
       const ops::ActivationType* _activation_type;
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 2684dd2c7..0c2e121fc 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -119,7 +119,14 @@ def _load(self):
                 % (config_name, ", ".join(sorted(_MODEL_LOADERS.keys())))
             )

-        model_class = getattr(transformers, loader.architecture_name)
+        # Lite-Whisper models reuse the tokenizer of the matching openai/whisper model.
+        if config.model_type == "lite-whisper":
+            base_name = self._model_name_or_path.split("/")[-1]  # e.g. "lite-whisper-large-v3"
+            base_name = base_name.replace("lite-", "")  # e.g. "whisper-large-v3"
+            tokenizer_path = f"openai/{base_name}"
+        else:
+            tokenizer_path = self._model_name_or_path
+
         tokenizer_class = transformers.AutoTokenizer

         kwargs = {
@@ -137,14 +144,18 @@ def _load(self):
         if self._trust_remote_code:
             kwargs["trust_remote_code"] = self._trust_remote_code

-        model = self.load_model(model_class, self._model_name_or_path, **kwargs)
+        if hasattr(transformers, loader.architecture_name):
+            model_class = getattr(transformers, loader.architecture_name)
+            model = self.load_model(model_class, self._model_name_or_path, **kwargs)
+        else:
+            model = transformers.AutoModel.from_pretrained(self._model_name_or_path, **kwargs)

         tokenizer_kwargs = {}
         if self._trust_remote_code:
             tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code

         tokenizer = self.load_tokenizer(
-            tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
+            tokenizer_class, tokenizer_path, **tokenizer_kwargs
         )

         spec = loader(model, tokenizer)
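Note: the `_load` changes above remap only the tokenizer path for `lite-whisper` configs; the model itself is still loaded from the original repository (through `AutoModel` when the architecture is not exported by `transformers`). A minimal sketch of the remapping — the checkpoint name is illustrative:

```python
# Reproduces the tokenizer-path remapping above; the checkpoint name is illustrative.
model_name = "efficient-speech/lite-whisper-large-v3"
base_name = model_name.split("/")[-1].replace("lite-", "")  # "whisper-large-v3"
tokenizer_path = f"openai/{base_name}"
print(tokenizer_path)  # openai/whisper-large-v3
```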
@@ -996,6 +1007,119 @@ def set_conv1d(self, spec, module):
         spec.weight = module.weight
         spec.bias = module.bias


+@register_loader("LiteWhisperConfig")
+class LiteWhisperLoader(WhisperLoader):
+    @property
+    def architecture_name(self):
+        return "LiteWhisperForConditionalGeneration"
+
+    def get_model_spec(self, model):
+        spec = whisper_spec.WhisperSpec(
+            model.config.encoder_layers,
+            model.config.encoder_attention_heads,
+            model.config.decoder_layers,
+            model.config.decoder_attention_heads,
+            low_rank=True,
+        )
+
+        self.set_encoder(spec.encoder, model.model.encoder)
+        self.set_decoder(spec.decoder, model.model.decoder)
+        self.set_linear(spec.decoder.projection, model.proj_out)
+
+        return spec
+
+    def set_config(self, config, model, tokenizer):
+        gen_config = getattr(model, "generation_config", None)
+
+        if gen_config is not None:
+            config.suppress_ids = gen_config.suppress_tokens
+            config.suppress_ids_begin = gen_config.begin_suppress_tokens
+            if hasattr(gen_config, "alignment_heads"):
+                config.alignment_heads = gen_config.alignment_heads
+            if hasattr(gen_config, "lang_to_id"):
+                config.lang_ids = sorted(gen_config.lang_to_id.values())
+        else:
+            config.suppress_ids = model.config.suppress_tokens
+            config.suppress_ids_begin = model.config.begin_suppress_tokens
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+
+        if getattr(config, "lang_ids", None) is None:
+            config.lang_ids = self._get_lang_ids_from_tokenizer(tokenizer)
+
+        if config.alignment_heads is None:
+            config.alignment_heads = _WHISPER_ALIGNMENT_HEADS.get(model.name_or_path)
+            if config.alignment_heads is None:
+                # Use the last half layers for alignment by default.
+                num_layers = model.config.decoder_layers
+                num_heads = model.config.decoder_attention_heads
+                config.alignment_heads = list(
+                    itertools.product(
+                        range(num_layers // 2, num_layers),
+                        range(num_heads),
+                    )
+                )
+
+    def set_encoder(self, spec, encoder):
+        """
+        Override encoder mapping for LiteWhisper.
+        """
+        self.set_conv1d(spec.conv1, encoder.conv1)
+        self.set_conv1d(spec.conv2, encoder.conv2)
+
+        self.set_common_layers(spec, encoder)
+
+        for layer_spec, layer in zip(spec.layer, encoder.layers):
+            self.set_low_rank_attention(
+                layer_spec.self_attention,
+                layer.self_attn,
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm,
+                layer.self_attn_layer_norm,
+            )
+
+            if hasattr(layer.fc1, "weight1"):
+                # low rank
+                self.set_low_rank_linear(layer_spec.ffn.linear_0, layer.fc1)
+            else:
+                layer_spec.ffn.linear_0 = common_spec.LinearSpec()
+                self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
+
+            if hasattr(layer.fc2, "weight1"):
+                # low rank
+                self.set_low_rank_linear(layer_spec.ffn.linear_1, layer.fc2)
+            else:
+                layer_spec.ffn.linear_1 = common_spec.LinearSpec()
+                self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
+
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
+
+    def set_low_rank_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        if quant_type == common_spec.Quantization.CT2:
+            spec.low_rank_weight_1 = module.weight1.transpose(0, 1).contiguous()
+            spec.low_rank_weight_2 = module.weight2.transpose(0, 1).contiguous()
+        else:
+            spec.low_rank_weight_1 = module.qweight1.transpose(0, 1).contiguous()
+            spec.low_rank_weight_2 = module.qweight2.transpose(0, 1).contiguous()
+            spec.weight_scale = module.scales
+            spec.weight_zero = module.qzeros
+
+        if module.bias is not None:
+            spec.bias = module.bias
+
+    def set_low_rank_or_linear_router(self, spec, module, i):
+        if hasattr(module, "weight1"):
+            self.set_low_rank_linear(spec.linear[i], module)
+        else:
+            spec.linear[i] = common_spec.LinearSpec()
+            self.set_linear(spec.linear[i], module)
+
+    def set_low_rank_attention(self, spec, attention):
+        self.set_low_rank_or_linear_router(spec, attention.q_proj, 0)
+        self.set_low_rank_or_linear_router(spec, attention.k_proj, 1)
+        self.set_low_rank_or_linear_router(spec, attention.v_proj, 2)
+        self.set_low_rank_or_linear_router(spec, attention.out_proj, 3)
+
+
 @register_loader("Wav2Vec2Config")
 class Wav2Vec2Loader(BartLoader):
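Note: `set_low_rank_linear` above assumes each factorized module exposes two factors, `weight1` of shape `[in, r]` and `weight2` of shape `[r, out]`, with forward pass `y = x @ weight1 @ weight2 + b`; both factors are stored transposed so the runtime can reuse `trans_b` GEMMs (see `src/layers/common.cc` further down). A minimal NumPy sketch of that layout, with illustrative shapes:

```python
import numpy as np

x = np.random.rand(3, 16).astype(np.float32)         # [batch, in]
weight1 = np.random.rand(16, 4).astype(np.float32)   # [in, r]
weight2 = np.random.rand(4, 16).astype(np.float32)   # [r, out]
bias = np.random.rand(16).astype(np.float32)

y = x @ weight1 @ weight2 + bias  # assumed forward pass of the factorized module

# What the converter stores (CTranslate2 keeps linear weights as [out, in]):
low_rank_weight_1 = weight1.T.copy()  # [r, in]
low_rank_weight_2 = weight2.T.copy()  # [out, r]

# What the runtime recomputes with two trans_b GEMMs:
y_runtime = (x @ low_rank_weight_1.T) @ low_rank_weight_2.T + bias
assert np.allclose(y, y_runtime)
```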
@@ -2908,6 +3032,7 @@ def main():
         (3, 4),
     ],
     "openai/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
+    "efficient-speech/whisper-tiny": [(2, 2), (3, 0), (3, 2), (3, 3), (3, 4), (3, 5)],
     "openai/whisper-base.en": [(3, 3), (4, 7), (5, 1), (5, 5), (5, 7)],
     "openai/whisper-base": [
         (3, 1),
@@ -3021,4 +3146,16 @@ def main():
         (24, 1),
         (25, 6),
     ],
+    "efficient-speech/whisper-large-v3": [
+        (7, 0),
+        (10, 17),
+        (12, 18),
+        (13, 12),
+        (16, 1),
+        (17, 14),
+        (19, 11),
+        (21, 4),
+        (24, 1),
+        (25, 6),
+    ],
 }
diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index f49d41121..2d61ad8b1 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -32,13 +32,14 @@ def __init__(
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
+        low_rank=False,
     ):
         self.queries_scale = model_spec.OPTIONAL

         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
-        self.linear = [
-            common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
-        ]
+        linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
+        count = 4 if low_rank else (2 if self_attention else 3)
+        self.linear = [linear_cls() for _ in range(count)]

         if relative_position:
             self.relative_position_keys = None
diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py
index 598a452d6..4209e41da 100644
--- a/python/ctranslate2/specs/common_spec.py
+++ b/python/ctranslate2/specs/common_spec.py
@@ -64,3 +64,15 @@ def __init__(self):
         self.weight = None
         self.weight_scale = model_spec.OPTIONAL
         self.multiply_by_sqrt_depth = model_spec.OPTIONAL
+
+
+class LinearLowRankSpec(model_spec.LayerSpec):
+    def __init__(self):
+        self.low_rank_weight_1 = None
+        self.low_rank_weight_2 = None
+        self.weight_scale = model_spec.OPTIONAL
+        self.weight_zero = model_spec.OPTIONAL
+        self.bias = model_spec.OPTIONAL
+
+    def has_bias(self):
+        return not isinstance(self.bias, str)
\ No newline at end of file
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index 230e62cfd..f3f789242 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -253,6 +253,7 @@ def __init__(
         rms_norm=False,
         num_heads_kv=None,
         sliding_window=None,
+        low_rank=False,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
@@ -261,8 +262,9 @@ def __init__(
             rms_norm=rms_norm,
             num_heads_kv=num_heads_kv,
             sliding_window=sliding_window,
+            low_rank=low_rank,
         )
-        self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
+        self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm, low_rank=low_rank)


 class TransformerDecoderLayerSpec(model_spec.LayerSpec):
@@ -340,10 +342,11 @@ def __init__(


 class FeedForwardSpec(model_spec.LayerSpec):
-    def __init__(self, glu=False, rms_norm=False):
+    def __init__(self, glu=False, rms_norm=False, low_rank=False):
         self.layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
-        self.linear_0 = common_spec.LinearSpec()
-        self.linear_1 = common_spec.LinearSpec()
+        linear_cls = common_spec.LinearLowRankSpec if low_rank else common_spec.LinearSpec
+        self.linear_0 = linear_cls()
+        self.linear_1 = linear_cls()

         if glu:
             self.linear_0_noact = common_spec.LinearSpec()
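Note: with the spec changes above applied, `low_rank=True` swaps every projection placeholder for a `LinearLowRankSpec` and gives self-attention four separate projections (q, k, v, out) instead of the fused QKV + output pair. A quick sketch:

```python
# Assumes this patch is applied; exercises only the spec classes changed above.
from ctranslate2.specs import attention_spec, common_spec, transformer_spec

attn = attention_spec.MultiHeadAttentionSpec(self_attention=True, low_rank=True)
ffn = transformer_spec.FeedForwardSpec(low_rank=True)

assert len(attn.linear) == 4  # q, k, v, out instead of fused qkv + out
assert all(isinstance(lin, common_spec.LinearLowRankSpec) for lin in attn.linear)
assert isinstance(ffn.linear_0, common_spec.LinearLowRankSpec)
```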
diff --git a/python/ctranslate2/specs/whisper_spec.py b/python/ctranslate2/specs/whisper_spec.py
index e32453e1c..d074c1d7b 100644
--- a/python/ctranslate2/specs/whisper_spec.py
+++ b/python/ctranslate2/specs/whisper_spec.py
@@ -32,6 +32,7 @@ def __init__(
         num_encoder_heads,
         num_decoder_layers,
         num_decoder_heads,
+        low_rank=False,
     ):
         """Initializes the model specification.

@@ -40,9 +41,10 @@ def __init__(
             num_encoder_heads: The number of encoder attention heads.
             num_decoder_layers: The number of decoder layers.
             num_decoder_heads: The number of decoder attention heads.
+            low_rank: Whether the model uses low-rank (Lite-Whisper) weights.
         """
         super().__init__()
-        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads)
+        self.encoder = WhisperEncoderSpec(num_encoder_layers, num_encoder_heads, low_rank=low_rank)
         self.decoder = transformer_spec.TransformerDecoderSpec(
             num_decoder_layers,
             num_decoder_heads,
@@ -66,12 +68,12 @@ def get_vocabulary_size(self):


 class WhisperEncoderSpec(model_spec.LayerSpec):
-    def __init__(self, num_layers, num_heads):
+    def __init__(self, num_layers, num_heads, low_rank=False):
         self.num_heads = np.dtype("int16").type(num_heads)
         self.conv1 = common_spec.Conv1DSpec()
         self.conv2 = common_spec.Conv1DSpec()
         self.position_encodings = transformer_spec.PositionEncoderSpec()
         self.layer_norm = common_spec.LayerNormSpec()
         self.layer = [
-            transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
+            transformer_spec.TransformerEncoderLayerSpec(low_rank=low_rank) for _ in range(num_layers)
         ]
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 6ad344410..005440c2f 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -362,11 +362,19 @@ namespace ctranslate2 {

       _linear[0](*q, fused_proj);

+      if (_is_low_rank) {  // Low-rank: _linear[0] above produced Q only; K and V have their own projections.
+        _linear[1](*q, keys_proj);
+        _linear[2](*q, values_proj);
+        queries_proj = std::move(fused_proj);
+      }
+
       dim_t beam_size = 1;

       bool prefilling = (_sliding_window > 0 && values_lengths);

       if (!_self_attention) {
+        if (_is_low_rank)
+          throw std::invalid_argument("Lite-Whisper does not use low-rank projections for cross-attention");
         queries_proj = std::move(fused_proj);

         if (cached_keys == nullptr || cached_keys->empty()) {
@@ -401,6 +409,8 @@

       } else {
         if (_num_heads_kv < _num_heads) {
+          if (_is_low_rank)
+            throw std::invalid_argument("Lite-Whisper does not use low-rank projections for multi-query or GQA");
           if (queries_padder)
             queries_padder->add_padding(fused_proj);

@@ -419,8 +429,15 @@
         }

       } else {
-        split_heads(fused_proj, 3 * _num_heads, queries_padder);
-        ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
+        if (!_is_low_rank) {
+          split_heads(fused_proj, 3 * _num_heads, queries_padder);
+          ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
+        }
+        else {
+          split_heads(queries_proj, _num_heads, queries_padder);
+          split_heads(keys_proj, _num_heads_kv, queries_padder);
+          split_heads(values_proj, _num_heads_kv, queries_padder);
+        }
       }

       if (_rotary_embeddings) {
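Note: the `attention.cc` changes above keep the regular path (one fused QKV GEMM followed by a single head split) and add a low-rank path where Q, K and V come from separate projections and heads are split per tensor. A NumPy sketch of the two shape flows, with illustrative sizes and the per-projection factorization and padding omitted:

```python
import numpy as np

batch, time, d_model, num_heads = 2, 10, 16, 4
x = np.random.rand(batch, time, d_model).astype(np.float32)

def split_heads(t, heads):
    b, s, d = t.shape
    return t.reshape(b, s, heads, d // heads).transpose(0, 2, 1, 3)

# Regular path: one fused projection, then split into Q/K/V.
w_fused = np.random.rand(3 * d_model, d_model).astype(np.float32)
q, k, v = np.split(split_heads(x @ w_fused.T, 3 * num_heads), 3, axis=1)

# Low-rank path: separate projections, heads split per tensor.
w_q, w_k, w_v = (np.random.rand(d_model, d_model).astype(np.float32) for _ in range(3))
q2 = split_heads(x @ w_q.T, num_heads)
k2 = split_heads(x @ w_k.T, num_heads)
v2 = split_heads(x @ w_v.T, num_heads)

assert q.shape == q2.shape == (batch, num_heads, time, d_model // num_heads)
```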
diff --git a/src/layers/attention_layer.cc b/src/layers/attention_layer.cc
index c9ae67409..9e91b8e1e 100644
--- a/src/layers/attention_layer.cc
+++ b/src/layers/attention_layer.cc
@@ -51,10 +51,25 @@ namespace ctranslate2 {
       return alibi;
     }

+    static bool set_low_rank(const models::Model& model, const std::string& scope) {
+      const dim_t max_layers = 4;
+      for (dim_t i = 0; i < max_layers; ++i) {
+        std::string prefix = scope + "/linear_" + std::to_string(i);
+        const StorageView* w1 = model.get_variable_if_exists(prefix + "/low_rank_weight_1");
+        const StorageView* w2 = model.get_variable_if_exists(prefix + "/low_rank_weight_2");
+        if (w1 && w2) {
+          return true;
+        }
+      }
+      // No low-rank weight pair was found, so this attention layer is not low-rank.
+      return false;
+    }
+
     static std::vector<Dense> make_linear_layers(const models::Model& model,
                                                  const std::string& scope,
-                                                 bool self_attention) {
-      const dim_t num_linear_layers = self_attention ? 2 : 3;
+                                                 bool self_attention,
+                                                 bool is_low_rank) {
+      const dim_t num_linear_layers = is_low_rank ? 4 : (self_attention ? 2 : 3);
       std::vector<Dense> layers;
       layers.reserve(num_linear_layers);
       for (dim_t i = 0; i < num_linear_layers; ++i)
@@ -117,11 +132,12 @@ namespace ctranslate2 {
                                    bool is_decoder,
                                    Alibi* alibi,
                                    bool is_flash_attn)
-      : _tensor_parallel(model.tensor_parallel())
+      : _is_low_rank(set_low_rank(model, scope))
+      , _tensor_parallel(model.tensor_parallel())
       , _num_heads(_tensor_parallel ? SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads)
       , _self_attention(self_attention)
       , _is_decoder(is_decoder)
-      , _linear(make_linear_layers(model, scope, self_attention))
+      , _linear(make_linear_layers(model, scope, self_attention, _is_low_rank))
       , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size())
       , _d_head(model.get_attribute_with_default(scope + "/head_dim", _d_model / _num_heads))
       , _pre_norm(pre_norm)
diff --git a/src/layers/common.cc b/src/layers/common.cc
index c6d1cd0b5..465a164b3 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -250,6 +250,9 @@ namespace ctranslate2 {
       return _encoding.dim(1);
     }

+    static bool has_low_rank(const models::Model& model, const std::string& scope) {
+      return model.get_variable_if_exists(scope + "/low_rank_weight_1") != nullptr;
+    }

     static const StorageView& get_linear_weight(const models::Model& model,
                                                 const std::string& scope,
@@ -268,7 +271,9 @@ namespace ctranslate2 {
                  const ops::ActivationType* activation_type,
                  const bool is_layer_out)
       : _packed_weight(false)
-      , _weight(get_linear_weight(model, scope, &_packed_weight))
+      , _is_low_rank(has_low_rank(model, scope))
+      , _weight(_is_low_rank ? *model.get_variable_if_exists(scope + "/low_rank_weight_1") : get_linear_weight(model, scope, &_packed_weight))
+      , _weight2(_is_low_rank ? model.get_variable_if_exists(scope + "/low_rank_weight_2") : nullptr)
       , _bias(model.get_variable_if_exists(scope + "/bias"))
       , _qscale(model.get_variable_if_exists(scope + "/weight_scale"))
       , _qzero(model.get_variable_if_exists(scope + "/weight_zero"))
@@ -291,6 +296,13 @@ namespace ctranslate2 {
                  /*a_is_packed=*/false,
                  _packed_weight,
                  _quantized_gemm ? nullptr : activation_type)
+      , _gemm_op_low_rank(/*alpha=*/1,
+                          /*beta=*/0,
+                          /*trans_a=*/false,
+                          /*trans_b=*/true,
+                          /*a_is_packed=*/false,
+                          /*b_is_packed=*/false,
+                          /*activation_type=*/nullptr)
       , _quantize_op(model.use_global_int16_scale() ? ops::Quantize::ScaleType::GLOBAL
                                                     : ops::Quantize::ScaleType::PER_LAYER,
@@ -307,6 +319,11 @@ namespace ctranslate2 {
     }

     dim_t Dense::output_size() const {
+      if (_is_low_rank) {
+        if (_partial_weight)
+          throw std::runtime_error("Low rank dense layer does not support partial weights");
+        return _weight2->dim(0);
+      }
       return _partial_weight ? _partial_weight.dim(0) : _weight.dim(0);
     }

@@ -338,8 +355,11 @@ namespace ctranslate2 {

     void Dense::operator()(const StorageView& input, StorageView& output) const {
       PROFILE("Dense");
+      if (_is_low_rank && !_partial_weight.empty())
+        throw std::runtime_error("Low rank dense layer does not support partial weights");
       const StorageView* qscale = _partial_qscale.empty() ? _qscale : &_partial_qscale;
       const StorageView* weight = _partial_weight.empty() ? &_weight : &_partial_weight;
+      const StorageView* weight2 = _is_low_rank ? _weight2 : nullptr;
       const StorageView* bias = _partial_bias.empty() ? _bias : &_partial_bias;
       const StorageView* compensation = (_partial_u8_shift_compensation.empty()
                                          ? _u8_shift_compensation
@@ -349,6 +369,8 @@
       if (affected_by_tp && ScopedMPISetter::getCurRank() != 0)
         bias = nullptr;
       if (_quantized_gemm) {
+        if (_is_low_rank)
+          throw std::runtime_error("Low rank dense layer is not supported with quantized gemm");
         const auto device = input.device();
         StorageView qinput(_weight.dtype(), device);
         StorageView qinput_scale(_qscale->dtype(), device);
@@ -396,6 +418,8 @@
                        output,
                        bias);
       } else if (_qzero && _qscale) {
+        if (_is_low_rank)
+          throw std::runtime_error("Low rank dense layer is not supported with quantized gemm");
         switch (_quant_method) {
         case models::QUANTIZATION_TYPE::AWQ_GEMM:
           if (input.dim(0) * input.dim(1) >= 1024) {
@@ -428,7 +452,17 @@
                                          "support only ct2 and awq quantization");
         }
       } else {
-        _gemm_op(input, *weight, output, nullptr, bias);
+        if (!_is_low_rank) {
+          _gemm_op(input, *weight, output, nullptr, bias);
+        } else {
+          StorageView intermediate_output(input.device(), input.dtype());
+
+          // First multiplication: input [M, K] * weight^T [K, R]
+          _gemm_op_low_rank(input, *weight, intermediate_output, nullptr);
+
+          // Second multiplication: intermediate [M, R] * weight2^T [R, N]
+          _gemm_op(intermediate_output, *weight2, output, nullptr, bias);
+        }
       }
     }
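Note: once converted, a low-rank model is served through the regular Whisper API; the low-rank `Dense` path above rejects quantized GEMM, so no quantization is requested here. A conversion sketch using the Python converter (the checkpoint name is illustrative, and `trust_remote_code=True` is assumed to be needed because the model class ships with the checkpoint):

```python
import ctranslate2
from ctranslate2.converters import TransformersConverter

converter = TransformersConverter(
    "efficient-speech/lite-whisper-large-v3",  # illustrative checkpoint name
    trust_remote_code=True,
)
output_dir = converter.convert("lite-whisper-large-v3-ct2")

model = ctranslate2.models.Whisper(output_dir)  # used like any other converted Whisper model
```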