From 39f48f2e843df52245e6c857326e1115bca12b03 Mon Sep 17 00:00:00 2001 From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com> Date: Thu, 4 Jul 2024 10:27:40 +0200 Subject: [PATCH] Quantzation AWQ GEMM + GEMV (#1727) * quantzation awq gemm + gemv * fix pipeline * fix pipeline * fix pipeline * fix dequantize awq * remove duplicated code --- CMakeLists.txt | 10 + README.md | 2 +- docs/quantization.md | 19 + include/ctranslate2/layers/common.h | 3 + include/ctranslate2/models/model.h | 17 +- include/ctranslate2/ops/awq/dequantize_awq.h | 26 + include/ctranslate2/ops/awq/gemm.h | 27 + include/ctranslate2/ops/awq/gemv.h | 33 ++ include/ctranslate2/ops/gemm.h | 3 +- include/ctranslate2/ops/mean.h | 3 +- include/ctranslate2/ops/ops.h | 4 + include/ctranslate2/ops/sum.h | 17 + python/ctranslate2/converters/transformers.py | 150 ++++- python/ctranslate2/converters/utils.py | 19 + python/ctranslate2/specs/common_spec.py | 9 + python/ctranslate2/specs/transformer_spec.py | 34 +- src/layers/common.cc | 35 ++ src/models/model.cc | 100 ++-- src/ops/awq/dequantize.cc | 24 + src/ops/awq/dequantize.cuh | 79 +++ src/ops/awq/dequantize_cpu.cc | 23 + src/ops/awq/dequantize_gpu.cu | 115 ++++ src/ops/awq/gemm.cc | 34 ++ src/ops/awq/gemm_cpu.cc | 25 + src/ops/awq/gemm_gpu.cu | 543 +++++++++++++++++ src/ops/awq/gemv.cc | 39 ++ src/ops/awq/gemv_cpu.cc | 40 ++ src/ops/awq/gemv_gpu.cu | 548 ++++++++++++++++++ src/ops/mean.cc | 2 +- src/ops/mean_cpu.cc | 6 +- src/ops/mean_gpu.cu | 8 +- src/ops/sum.cc | 44 ++ 32 files changed, 1964 insertions(+), 77 deletions(-) create mode 100644 include/ctranslate2/ops/awq/dequantize_awq.h create mode 100644 include/ctranslate2/ops/awq/gemm.h create mode 100644 include/ctranslate2/ops/awq/gemv.h create mode 100644 include/ctranslate2/ops/sum.h create mode 100644 src/ops/awq/dequantize.cc create mode 100644 src/ops/awq/dequantize.cuh create mode 100644 src/ops/awq/dequantize_cpu.cc create mode 100644 src/ops/awq/dequantize_gpu.cu create mode 100644 src/ops/awq/gemm.cc create mode 100644 src/ops/awq/gemm_cpu.cc create mode 100644 src/ops/awq/gemm_gpu.cu create mode 100644 src/ops/awq/gemv.cc create mode 100644 src/ops/awq/gemv_cpu.cc create mode 100644 src/ops/awq/gemv_gpu.cu create mode 100644 src/ops/sum.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 18eeb826c..ac94aac57 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -191,6 +191,13 @@ set(SOURCES src/ops/transpose.cc src/ops/nccl_ops.cc src/ops/nccl_ops_cpu.cc + src/ops/awq/dequantize.cc + src/ops/awq/dequantize_cpu.cc + src/ops/awq/gemm.cc + src/ops/awq/gemm_cpu.cc + src/ops/awq/gemv.cc + src/ops/awq/gemv_cpu.cc + src/ops/sum.cc src/padder.cc src/profiler.cc src/random.cc @@ -595,6 +602,9 @@ if (WITH_CUDA) src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu + src/ops/awq/gemm_gpu.cu + src/ops/awq/gemv_gpu.cu + src/ops/awq/dequantize_gpu.cu ) set_source_files_properties( diff --git a/README.md b/README.md index 7ce65486b..bfb64c851 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ The project is production-oriented and comes with [backward compatibility guaran ## Key features * **Fast and efficient execution on CPU and GPU**
The execution [is significantly faster and requires fewer resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc. -* **Quantization and reduced precision**
The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), and 8-bit integers (INT8). +* **Quantization and reduced precision**
The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), 8-bit integers (INT8), and 4-bit AWQ quantization (INT4). * **Multiple CPU architectures support**
The project supports x86-64 and AArch64/ARM64 processors and integrates multiple backends that are optimized for these platforms: [Intel MKL](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html), [oneDNN](https://github.com/oneapi-src/oneDNN), [OpenBLAS](https://www.openblas.net/), [Ruy](https://github.com/google/ruy), and [Apple Accelerate](https://developer.apple.com/documentation/accelerate). * **Automatic CPU detection and code dispatch**
One binary can include multiple backends (e.g. Intel MKL and oneDNN) and instruction set architectures (e.g. AVX, AVX2) that are automatically selected at runtime based on the CPU information. * **Parallel and asynchronous execution**
Multiple batches can be processed in parallel and asynchronously using multiple GPUs or CPU cores. diff --git a/docs/quantization.md b/docs/quantization.md index aa1d247d4..296c57000 100644 --- a/docs/quantization.md +++ b/docs/quantization.md @@ -6,6 +6,7 @@ Quantization is a technique that can reduce the model size and accelerate its ex * 16-bit integers (INT16) * 16-bit floating points (FP16) * 16-bit brain floating points (BF16) +* 4-bit AWQ Quantization ```{tip} See the benchmark results in the main [README](https://github.com/OpenNMT/CTranslate2#benchmarks) to compare the performance and memory usage with and without quantization. @@ -161,3 +162,21 @@ In this mode, all model weights are stored in half precision and all layers are * NVIDIA GPU with Compute Capability >= 8.0 In this mode, all model weights are stored in BF16 and all layers are run with this type. + +### 4-bit AWQ + +The compute type would be `int32_float16` + +**Supported on:** + +* NVIDIA GPU with Compute Capability >= 7.5 + +In this mode, all model weights are stored in half precision and all layers are run in half precision. Other parameters like scale and zero are stored in ``int32``. + +For example, + +```bash + ct2-transformers-converter --model TheBloke/Llama-2-7B-AWQ --copy_files tokenizer.model --output_dir ct2_model +``` + +We have to quantize the model with AWQ first, then convert it to CT2 format. \ No newline at end of file diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index 7ea5e9126..3985b3feb 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -138,16 +138,19 @@ namespace ctranslate2 { const StorageView& _weight; const StorageView* _bias; const StorageView* _qscale; + const StorageView* _qzero; const StorageView* _u8_shift_compensation; StorageView _partial_weight; StorageView _partial_bias; StorageView _partial_qscale; StorageView _partial_u8_shift_compensation; const DataType _output_type; + const models::QUANTIZATION_TYPE _quant_method; const bool _quantized_gemm; const ops::Gemm _gemm_op; const ops::Quantize _quantize_op; const ops::Dequantize _dequantize_op; + const ops::ActivationType* _activation_type; const bool _is_layer_out; }; diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index babba1455..32e4f8403 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -11,6 +11,12 @@ namespace ctranslate2 { namespace models { + enum class QUANTIZATION_TYPE { + CT2, + AWQ_GEMM, + AWQ_GEMV + }; + static const size_t current_binary_version = 6; // Checks whether the provided path could contain a CTranslate2 model. 
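For reference, a minimal generation sketch that follows the `ct2-transformers-converter` example added to docs/quantization.md above, assuming the `ct2_model` directory produced by that command; the tokenizer source, prompt, and sampling settings are illustrative.

```python
import ctranslate2
import transformers

# Load the converted AWQ model (directory produced by ct2-transformers-converter above).
generator = ctranslate2.Generator("ct2_model", device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")

prompt = "What is AWQ quantization?"  # illustrative prompt
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

results = generator.generate_batch(
    [tokens],
    max_length=128,
    sampling_topk=10,
    include_prompt_in_result=False,
)
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(results[0].sequences[0])))
```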
@@ -90,6 +96,14 @@ namespace ctranslate2 { return _use_flash_attention; } + QUANTIZATION_TYPE quant_method() const { + return _quant_method; + } + + void set_quant_method(QUANTIZATION_TYPE type) { + _quant_method = type; + } + virtual bool use_global_int16_scale() const { return true; } @@ -160,7 +174,7 @@ namespace ctranslate2 { private: void process_linear_weights(); - void set_compute_type(ComputeType type, Device device, int device_index); + void set_compute_type(ComputeType type, Device device, int device_index, bool update_weight=true); void ensure_dtype(const std::string& name, StorageView& variable, const DataType target_dtype); @@ -177,6 +191,7 @@ namespace ctranslate2 { std::unordered_map> _variable_index; bool _use_flash_attention = false; bool _tensor_parallel = false; + QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2; }; template<> diff --git a/include/ctranslate2/ops/awq/dequantize_awq.h b/include/ctranslate2/ops/awq/dequantize_awq.h new file mode 100644 index 000000000..3b1a20f6a --- /dev/null +++ b/include/ctranslate2/ops/awq/dequantize_awq.h @@ -0,0 +1,26 @@ +#pragma once + +#include "../op.h" + +namespace ctranslate2 { + namespace ops { + + class DequantizeAwq : public Op { + public: + DequantizeAwq(); + + void operator()(const StorageView& input, + const StorageView& scale, + const StorageView& zeros, + StorageView& output) const; + + private: + template + void dequantize(const StorageView& input, + const StorageView& scale, + const StorageView& zeros, + StorageView& output) const; + }; + + } +} diff --git a/include/ctranslate2/ops/awq/gemm.h b/include/ctranslate2/ops/awq/gemm.h new file mode 100644 index 000000000..8233f12f7 --- /dev/null +++ b/include/ctranslate2/ops/awq/gemm.h @@ -0,0 +1,27 @@ +#pragma once + +#include "../activation.h" +#include "../gemm.h" + +namespace ctranslate2 { + namespace ops { + class GemmAwq : public Gemm { + public: + using Gemm::Gemm; + void operator()(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c, + const StorageView* bias = nullptr) const; + + private: + template + void compute(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const; + }; + } +} \ No newline at end of file diff --git a/include/ctranslate2/ops/awq/gemv.h b/include/ctranslate2/ops/awq/gemv.h new file mode 100644 index 000000000..62866e515 --- /dev/null +++ b/include/ctranslate2/ops/awq/gemv.h @@ -0,0 +1,33 @@ +#pragma once + +#include "../activation.h" +#include "../gemm.h" + +namespace ctranslate2 { + namespace ops { + class GemvAwq : public Gemm { + public: + using Gemm::Gemm; + void operator()(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c, + const StorageView* bias = nullptr) const; + + private: + template + void compute_gemv(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const; + template + void compute_gemv2(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const; + }; + } +} \ No newline at end of file diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index c309063d6..3c4efbb02 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -39,6 +39,8 @@ namespace ctranslate2 { const dim_t k, const dim_t n, const float alpha); + protected: + const ActivationType* 
_activation_type; private: float _alpha; @@ -47,7 +49,6 @@ namespace ctranslate2 { bool _trans_b; bool _a_is_packed; bool _b_is_packed; - const ActivationType* _activation_type; template void compute(const StorageView& a, diff --git a/include/ctranslate2/ops/mean.h b/include/ctranslate2/ops/mean.h index 251f342dc..501394c3b 100644 --- a/include/ctranslate2/ops/mean.h +++ b/include/ctranslate2/ops/mean.h @@ -11,12 +11,13 @@ namespace ctranslate2 { void operator()(const StorageView& input, StorageView& output) const override; - private: + protected: template void compute(const StorageView& input, const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, + const bool get_sum, StorageView& output) const; const dim_t _axis; diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index ceca49450..ed9db4265 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -39,3 +39,7 @@ #include "slide.h" #include "nccl_ops.h" #include "flash_attention.h" +#include "awq/gemm.h" +#include "awq/gemv.h" +#include "awq/dequantize_awq.h" +#include "sum.h" diff --git a/include/ctranslate2/ops/sum.h b/include/ctranslate2/ops/sum.h new file mode 100644 index 000000000..3a240d8fc --- /dev/null +++ b/include/ctranslate2/ops/sum.h @@ -0,0 +1,17 @@ +#pragma once + +#include "op.h" +#include "mean.h" + +namespace ctranslate2 { + namespace ops { + + class Sum : public Mean { + public: + Sum(const dim_t axis); + + void operator()(const StorageView& input, StorageView& output) const override; + }; + + } +} diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 719983a3d..849123a50 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -43,6 +43,11 @@ "su": attention_spec.RotaryScalingType.Su, } +_SUPPORTED_QUANTIZATION = { + "gemm": common_spec.Quantization.AWQ_GEMM, + "gemv": common_spec.Quantization.AWQ_GEMV, +} + _MODEL_LOADERS = {} @@ -217,8 +222,14 @@ def set_layer_norm(self, spec, module): spec.gamma = module.weight spec.beta = module.bias - def set_linear(self, spec, module): - spec.weight = module.weight + def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2): + if quant_type == common_spec.Quantization.CT2: + spec.weight = module.weight + else: + spec.weight = module.qweight + spec.weight_scale = module.scales + spec.weight_zero = module.qzeros + if isinstance(module, transformers.Conv1D): spec.weight = spec.weight.transpose(0, 1) if module.bias is not None: @@ -1407,6 +1418,26 @@ def get_model_spec(self, model): rotary_scaling_type = None rotary_scaling_factor = 1 + quantization_config = getattr(model.config, "quantization_config", None) + if quantization_config: + if quantization_config.quant_method == "awq": + quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version) + if quant_type is None: + raise NotImplementedError( + "Quantization type '%s' is not yet implemented. 
" + "The following Quantization types are currently supported: %s" + % ( + quantization_config.quant_method, + ", ".join(_SUPPORTED_QUANTIZATION.keys()), + ) + ) + quant_group_size = quantization_config.group_size + quant_bits = quantization_config.bits + else: + quant_type = common_spec.Quantization.CT2 + quant_group_size = None + quant_bits = None + spec = transformer_spec.TransformerDecoderModelSpec.from_config( num_layers, num_heads, @@ -1420,9 +1451,12 @@ def get_model_spec(self, model): rotary_scaling_factor=rotary_scaling_factor, rotary_base=getattr(model.config, "rope_theta", 10000), num_heads_kv=num_heads_kv, + quant_type=quant_type, + quant_group_size=quant_group_size, + quant_bits=quant_bits, ) - self.set_decoder(spec.decoder, model.model) + self.set_decoder(spec.decoder, model.model, quant_type) self.set_linear(spec.decoder.projection, model.lm_head) return spec @@ -1451,7 +1485,7 @@ def set_config(self, config, model, tokenizer): def set_layer_norm(self, spec, layer_norm): spec.gamma = layer_norm.weight - def set_decoder(self, spec, module): + def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2): spec.scale_embeddings = False self.set_embeddings(spec.embeddings, module.embed_tokens) self.set_layer_norm(spec.layer_norm, module.norm) @@ -1464,17 +1498,39 @@ def set_decoder(self, spec, module): layer_spec.ffn.layer_norm, layer.post_attention_layernorm ) - wq = layer.self_attn.q_proj.weight - wk = layer.self_attn.k_proj.weight - wv = layer.self_attn.v_proj.weight - wo = layer.self_attn.o_proj.weight + split_layers = [common_spec.LinearSpec() for _ in range(3)] + self.set_linear( + split_layers[0], layer.self_attn.q_proj, quant_type=quant_type + ) + self.set_linear( + split_layers[1], layer.self_attn.k_proj, quant_type=quant_type + ) + self.set_linear( + split_layers[2], layer.self_attn.v_proj, quant_type=quant_type + ) - layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv]) - layer_spec.self_attention.linear[1].weight = wo + if quant_type == common_spec.Quantization.CT2: + utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers) + else: + cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0 + utils.fuse_linear_prequant( + layer_spec.self_attention.linear[0], split_layers, cc_dim + ) + self.set_linear( + layer_spec.self_attention.linear[1], + layer.self_attn.o_proj, + quant_type=quant_type, + ) - self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj) - self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj) - self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj) + self.set_linear( + layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type + ) + self.set_linear( + layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type + ) + self.set_linear( + layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type + ) delattr(layer, "self_attn") delattr(layer, "mlp") @@ -1512,6 +1568,26 @@ def get_model_spec(self, model): rotary_scaling_type = None rotary_scaling_factor = 1 + quantization_config = getattr(model.config, "quantization_config", None) + if quantization_config: + if quantization_config.quant_method == "awq": + quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version) + if quant_type is None: + raise NotImplementedError( + "Quantization type '%s' is not yet implemented. 
" + "The following Quantization types are currently supported: %s" + % ( + quantization_config.quant_method, + ", ".join(_SUPPORTED_QUANTIZATION.keys()), + ) + ) + quant_group_size = quantization_config.group_size + quant_bits = quantization_config.bits + else: + quant_type = common_spec.Quantization.CT2 + quant_group_size = None + quant_bits = None + spec = transformer_spec.TransformerDecoderModelSpec.from_config( num_layers, num_heads, @@ -1526,9 +1602,12 @@ def get_model_spec(self, model): rotary_base=getattr(model.config, "rope_theta", 10000), num_heads_kv=num_heads_kv, sliding_window=sliding_window, + quant_type=quant_type, + quant_group_size=quant_group_size, + quant_bits=quant_bits, ) - self.set_decoder(spec.decoder, model.model) + self.set_decoder(spec.decoder, model.model, quant_type=quant_type) self.set_linear(spec.decoder.projection, model.lm_head) return spec @@ -1553,7 +1632,7 @@ def set_config(self, config, model, tokenizer): def set_layer_norm(self, spec, layer_norm): spec.gamma = layer_norm.weight - def set_decoder(self, spec, module): + def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2): spec.scale_embeddings = False self.set_embeddings(spec.embeddings, module.embed_tokens) self.set_layer_norm(spec.layer_norm, module.norm) @@ -1565,18 +1644,39 @@ def set_decoder(self, spec, module): self.set_layer_norm( layer_spec.ffn.layer_norm, layer.post_attention_layernorm ) + split_layers = [common_spec.LinearSpec() for _ in range(3)] + self.set_linear( + split_layers[0], layer.self_attn.q_proj, quant_type=quant_type + ) + self.set_linear( + split_layers[1], layer.self_attn.k_proj, quant_type=quant_type + ) + self.set_linear( + split_layers[2], layer.self_attn.v_proj, quant_type=quant_type + ) - wq = layer.self_attn.q_proj.weight - wk = layer.self_attn.k_proj.weight - wv = layer.self_attn.v_proj.weight - wo = layer.self_attn.o_proj.weight - - layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv]) - layer_spec.self_attention.linear[1].weight = wo + if quant_type == common_spec.Quantization.CT2: + utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers) + else: + cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0 + utils.fuse_linear_prequant( + layer_spec.self_attention.linear[0], split_layers, cc_dim + ) + self.set_linear( + layer_spec.self_attention.linear[1], + layer.self_attn.o_proj, + quant_type=quant_type, + ) - self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj) - self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj) - self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj) + self.set_linear( + layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type + ) + self.set_linear( + layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type + ) + self.set_linear( + layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type + ) delattr(layer, "self_attn") delattr(layer, "mlp") diff --git a/python/ctranslate2/converters/utils.py b/python/ctranslate2/converters/utils.py index 48d9d27a1..8ce9e937b 100644 --- a/python/ctranslate2/converters/utils.py +++ b/python/ctranslate2/converters/utils.py @@ -33,6 +33,25 @@ def fuse_linear(spec, layers): ) +def fuse_linear_prequant(spec, layers, axis): + if not layers: + raise ValueError("Cannot fuse linear layers: at least one layer is required") + params = ["weight", "weight_scale", "weight_zero"] + if isinstance(layers[0].weight, np.ndarray): + concatenate = np.concatenate + else: + import torch + + concatenate = torch.cat 
+ + for param in params: + setattr( + spec, + param, + concatenate([getattr(layer, param) for layer in layers], axis=axis), + ) + + def permute_for_sliced_rotary(weight, num_heads, rotary_dim=None): """Permutes the weight to use the sliced rotary implementation.""" if rotary_dim is not None: diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index b517ef77c..b1162839c 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -23,6 +23,14 @@ class EmbeddingsMerge(enum.IntEnum): ADD = 1 +class Quantization(enum.IntEnum): + """Activation type.""" + + CT2 = 0 + AWQ_GEMM = 1 + AWQ_GEMV = 2 + + class LayerNormSpec(model_spec.LayerSpec): def __init__(self, rms_norm=False): self.gamma = None @@ -36,6 +44,7 @@ class LinearSpec(model_spec.LayerSpec): def __init__(self): self.weight = None self.weight_scale = model_spec.OPTIONAL + self.weight_zero = model_spec.OPTIONAL self.bias = model_spec.OPTIONAL def has_bias(self): diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 2325f7bbf..abb812c8b 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -105,6 +105,9 @@ def __init__( num_heads_kv: Optional[int] = None, head_dim: Optional[int] = None, sliding_window: Optional[int] = None, + quant_type: Optional[common_spec.Quantization] = None, + quant_group_size: Optional[int] = None, + quant_bits: Optional[int] = None, ): """Initializes a Transformer decoder specification. @@ -147,7 +150,12 @@ def __init__( multi_query_attention: Use multi-query attention (alias for num_heads_kv=1). num_heads_kv: Number of attention heads for the key and value. sliding_window: Max sequence length to retain in KV Cache. + quant_type: quantization type used (like awq... for lower bit quantization) + quant_group_size: group size of the lower bit quantization + quant_bits: number of bit of the quantization (ex: 4bit) """ + + self._config = dict() if parallel_residual: if not pre_norm: raise ValueError("The GPT-J block expects a pre-norm architecture") @@ -215,7 +223,7 @@ def __init__( for _ in range(num_layers) ] self.start_from_zero_embedding = False - self.multi_query_attention = multi_query_attention or ( + self._config["multi_query_attention"] = multi_query_attention or ( num_heads_kv != num_heads ) @@ -223,6 +231,15 @@ def __init__( self.project_in = common_spec.LinearSpec() self.project_out = common_spec.LinearSpec() + if quant_type is not None: + self._config["quantization_type"] = quant_type + self._config["quantization_bits"] = quant_bits + self._config["quantization_group_size"] = quant_group_size + + @property + def config(self): + return self._config + class TransformerEncoderLayerSpec(model_spec.LayerSpec): def __init__( @@ -485,9 +502,8 @@ def __init__(self, decoder: TransformerDecoderSpec): super().__init__() self.decoder = decoder - self._config.add_attribute( - "multi_query_attention", self.decoder.multi_query_attention - ) + for key, value in self.decoder.config.items(): + self._config.add_attribute(key, value) @classmethod def from_config( @@ -518,6 +534,9 @@ def from_config( num_heads_kv: Optional[int] = None, head_dim: Optional[int] = None, sliding_window: Optional[int] = None, + quant_type: Optional[common_spec.Quantization] = None, + quant_group_size: Optional[int] = None, + quant_bits: Optional[int] = None, ): """Creates a Transformer decoder model specification. 
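For reference, a small NumPy sketch of the group-wise math behind the `weight`, `weight_scale`, and `weight_zero` parameters above: each 4-bit weight is dequantized as `(q - z) * s`, with one scale and zero per `quantization_group_size` input channels. Shapes are illustrative and the 4-bit values are kept unpacked here; real AWQ checkpoints pack eight of them per `int32`.

```python
import numpy as np

def dequantize_awq_reference(q, zeros, scales, group_size):
    """Reference for w = (q - z) * s with one scale/zero per group of input channels."""
    # q:      [in_features, out_features]                4-bit values in [0, 15], unpacked
    # zeros:  [in_features // group_size, out_features]  unpacked zero points
    # scales: [in_features // group_size, out_features]  per-group scales
    group = np.repeat(np.arange(q.shape[0] // group_size), group_size)
    return (q - zeros[group]).astype(np.float16) * scales[group]

in_features, out_features, group_size = 64, 32, 32  # illustrative sizes
rng = np.random.default_rng(0)
q = rng.integers(0, 16, size=(in_features, out_features))
z = rng.integers(0, 16, size=(in_features // group_size, out_features))
s = rng.random((in_features // group_size, out_features)).astype(np.float16)
print(dequantize_awq_reference(q, z, s, group_size).shape)  # (64, 32)
```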
@@ -553,7 +572,11 @@ def from_config( attention layer norms. multi_query_attention: Use multi-query attention (alias for num_heads_kv=1). num_heads_kv: Number of attention heads for the key and value. + head_dim: Number of head sliding_window: max sequence length to retain KV cache + quant_type: quantization type used (like awq... for lower bit quantization) + quant_group_size: group size of the lower bit quantization + quant_bits: number of bit of the quantization (ex: 4bit) """ decoder = TransformerDecoderSpec( num_layers, @@ -583,6 +606,9 @@ def from_config( num_heads_kv=num_heads_kv, head_dim=head_dim, sliding_window=sliding_window, + quant_type=quant_type, + quant_group_size=quant_group_size, + quant_bits=quant_bits, ) return cls(decoder) diff --git a/src/layers/common.cc b/src/layers/common.cc index 6f56c01ef..86fb66a7d 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -271,6 +271,7 @@ namespace ctranslate2 { , _weight(get_linear_weight(model, scope, &_packed_weight)) , _bias(model.get_variable_if_exists(scope + "/bias")) , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) + , _qzero(model.get_variable_if_exists(scope + "/weight_zero")) , _u8_shift_compensation((_weight.device() == Device::CPU && _weight.dtype() == DataType::INT8 && cpu::prefer_u8s8s32_gemm()) @@ -281,6 +282,7 @@ namespace ctranslate2 { , _partial_qscale(_weight.device(), DataType::FLOAT32) , _partial_u8_shift_compensation(_weight.device(), DataType::INT32) , _output_type(get_default_float_type(model.effective_compute_type())) + , _quant_method(model.quant_method()) , _quantized_gemm(_weight.dtype() == DataType::INT16 || _weight.dtype() == DataType::INT8) , _gemm_op(/*alpha=*/1, /*beta=*/0, @@ -295,6 +297,7 @@ namespace ctranslate2 { /*shift_to_uint8=*/bool(_u8_shift_compensation), /*round_before_cast=*/model.round_before_cast_in_quantization()) , _dequantize_op(activation_type) + , _activation_type(activation_type) , _is_layer_out(is_layer_out) { } @@ -392,6 +395,38 @@ namespace ctranslate2 { /*trans_b=*/true, output, bias); + } else if (_qzero && _qscale) { + switch (_quant_method) { + case models::QUANTIZATION_TYPE::AWQ_GEMM: + if (input.dim(0) * input.dim(1) >= 1024) { + StorageView weight_dequant(input.dtype(), input.device()); + ops::DequantizeAwq dequantize_awq_op; + dequantize_awq_op(*weight, *qscale, *_qzero, weight_dequant); + ops::Gemm gemm_op(/*alpha=*/1, + /*beta=*/0, + /*trans_a=*/false, + /*trans_b=*/false, + /*a_is_packed=*/false, + /*b_is_packed*/false, + _activation_type); + gemm_op(input, weight_dequant, output, nullptr, bias); + } else { + ops::GemmAwq gemm_awq_op(/*alpha=*/1, /*beta=*/0, /*trans_a=*/false, /*trans_b=*/false, + /*a_is_packed=*/false, /*b_is_packed=*/false, _activation_type); + gemm_awq_op(input, *weight, *qscale, *_qzero, output, bias); + } + break; + case models::QUANTIZATION_TYPE::AWQ_GEMV: + { + ops::GemvAwq gemv_awq_op(/*alpha=*/1, /*beta=*/0, /*trans_a=*/false, /*trans_b=*/false, + /*a_is_packed=*/false, /*b_is_packed=*/false, _activation_type); + gemv_awq_op(input, *weight, *qscale, *_qzero, output, bias); + break; + } + default: + throw std::invalid_argument("Dense forward: invalid quantized type," + "support only ct2 and awq quantization"); + } } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/models/model.cc b/src/models/model.cc index dd7273d57..b8e1c2d8f 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -173,7 +173,7 @@ namespace ctranslate2 { _device_index = index; } - void 
Model::set_compute_type(ComputeType type, Device device, int device_index) { + void Model::set_compute_type(ComputeType type, Device device, int device_index, bool update_weight) { if (_device != Device::CPU) throw std::runtime_error("set_compute_type expects the variables to be on CPU"); @@ -187,46 +187,47 @@ namespace ctranslate2 { device, device_index); - DataType weight_dtype = DataType::FLOAT32; - DataType float_dtype = DataType::FLOAT32; - std::tie(weight_dtype, float_dtype) = compute_type_to_data_type(_effective_compute_type); - if (_use_flash_attention && (float_dtype != DataType::FLOAT16 && float_dtype != DataType::BFLOAT16)) - throw std::runtime_error("FlashAttention only support fp16 and bf16 data type"); - - const auto variable_index = _variable_index; - for (auto& variable_pair : variable_index) { - const auto& name = variable_pair.first; - auto& variable = *variable_pair.second; - - // Convert "weight" variables to the expected compute type. - // Other float variables (e.g. biases) may be converted to another float type. - if (is_quantizable(name)) { - auto variable_weight_dtype = weight_dtype; - // For conv layer, we need to reshape to ensure dtype as its weights are 3D. - auto is_conv = name.find("conv") != std::string::npos; - auto kernel_size = -1; - if (is_conv) { - kernel_size = variable.dim(2); - variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)}); - // For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype. - if (device == Device::CUDA - #ifdef CT2_WITH_DNNL - || true - #endif - ) { - variable_weight_dtype = float_dtype; + if (update_weight) { + DataType weight_dtype = DataType::FLOAT32; + DataType float_dtype = DataType::FLOAT32; + std::tie(weight_dtype, float_dtype) = compute_type_to_data_type(_effective_compute_type); + if (_use_flash_attention && (float_dtype != DataType::FLOAT16 && float_dtype != DataType::BFLOAT16)) + throw std::runtime_error("FlashAttention only support fp16 and bf16 data type"); + + const auto variable_index = _variable_index; + for (auto& variable_pair : variable_index) { + const auto &name = variable_pair.first; + auto &variable = *variable_pair.second; + + // Convert "weight" variables to the expected compute type. + // Other float variables (e.g. biases) may be converted to another float type. + if (is_quantizable(name)) { + auto variable_weight_dtype = weight_dtype; + // For conv layer, we need to reshape to ensure dtype as its weights are 3D. + auto is_conv = name.find("conv") != std::string::npos; + auto kernel_size = -1; + if (is_conv) { + kernel_size = variable.dim(2); + variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)}); + // For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype. 
+ if (device == Device::CUDA +#ifdef CT2_WITH_DNNL + || true +#endif + ) { + variable_weight_dtype = float_dtype; + } } - } - ensure_dtype(name, variable, variable_weight_dtype); - // Undo reshape for conv weights - if (is_conv) { - variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size}); - } + ensure_dtype(name, variable, variable_weight_dtype); + // Undo reshape for conv weights + if (is_conv) { + variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size}); + } + } else if (is_convertible(variable, name) + && is_float_type(variable.dtype()) + && variable.dtype() != float_dtype) + variable = variable.to(float_dtype); } - else if (is_convertible(variable, name) - && is_float_type(variable.dtype()) - && variable.dtype() != float_dtype) - variable = variable.to(float_dtype); } } @@ -637,6 +638,10 @@ namespace ctranslate2 { " the config.json could lead to error! Try using the latest version of converters"); } + QUANTIZATION_TYPE quantization_type = QUANTIZATION_TYPE::CT2; + if (model->config.contains("quantization_type")) + model->set_quant_method(model->config["quantization_type"]); + for (uint32_t i = 0; i < num_variables; ++i) { auto name = consume(model_file); const size_t rank = consume(model_file); @@ -740,12 +745,24 @@ namespace ctranslate2 { variable = std::move(outputs[current_index]); } } - model->register_variable(std::move(name), std::move(variable)); } // Maybe quantize/dequantize/convert the variables to match the requested compute type. - model->set_compute_type(compute_type, device, device_index); + // if model is quantized with a specific type different with CT2, it use the specific kernel + // So have to keep the compute type for it. + switch (model->quant_method()) { + case QUANTIZATION_TYPE::CT2: + model->set_compute_type(compute_type, device, device_index); + break; + case QUANTIZATION_TYPE::AWQ_GEMM: + case QUANTIZATION_TYPE::AWQ_GEMV: + model->set_compute_type(ComputeType::FLOAT16, device, device_index, false); + break; + default: + throw std::invalid_argument("Quantization type is not supported"); + break; + } // Move variables to the target device. model->set_device(device, device_index); @@ -759,6 +776,7 @@ namespace ctranslate2 { model->register_variable_alias(alias, variable_name); // Also alias the quantization scale that could be associated to variable_name. 
model->register_variable_alias(alias + "_scale", variable_name + "_scale"); + model->register_variable_alias(alias + "_zero", variable_name + "_zero"); } } diff --git a/src/ops/awq/dequantize.cc b/src/ops/awq/dequantize.cc new file mode 100644 index 000000000..eb909e087 --- /dev/null +++ b/src/ops/awq/dequantize.cc @@ -0,0 +1,24 @@ +#include + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + DequantizeAwq::DequantizeAwq() = default; + + void DequantizeAwq::operator()(const StorageView& input, + const StorageView& scale, + const StorageView& zeros, + StorageView& output) const{ + PROFILE("Dequantize Awq"); + + if (input.dtype() != DataType::INT32 && output.dtype() != DataType::FLOAT16) + throw std::invalid_argument("Awq dequantization is only supported for int32 input and float16 output"); + if (input.device() == Device::CPU) + throw std::invalid_argument("Awq dequantization is only supported on GPU"); + + DEVICE_DISPATCH(input.device(), (dequantize(input, scale, zeros, output))); + } + } +} diff --git a/src/ops/awq/dequantize.cuh b/src/ops/awq/dequantize.cuh new file mode 100644 index 000000000..799f68c3d --- /dev/null +++ b/src/ops/awq/dequantize.cuh @@ -0,0 +1,79 @@ +#pragma once +#include + +namespace ctranslate2 { + namespace ops { + __device__ __forceinline__ int make_divisible(int c, int divisor){ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + return (c + divisor - 1) / divisor; +#endif + } + + __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) + { + uint4 result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. 
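+      // Why these constants work: in fp16, 0x6400 is 1024.0 and its mantissa ulp is
+      // exactly 1, so OR-ing a 4-bit value q into the low mantissa bits yields the
+      // half value (1024 + q); the bottom nibbles are then recovered with a single
+      // subtraction of 1024. The top nibbles sit four bits higher, so their pattern
+      // encodes (1024 + 16*q), which one fma maps back to q: (1024 + 16*q)/16 - 64 = q.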
+ + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + + return result; + } + } +} \ No newline at end of file diff --git a/src/ops/awq/dequantize_cpu.cc b/src/ops/awq/dequantize_cpu.cc new file mode 100644 index 000000000..76cf4c3a9 --- /dev/null +++ b/src/ops/awq/dequantize_cpu.cc @@ -0,0 +1,23 @@ +#include + +namespace ctranslate2 { + namespace ops { + template + void DequantizeAwq::dequantize(const StorageView&, + const StorageView&, + const StorageView&, + StorageView&) const { + throw std::runtime_error("AWQ dequantize is not applied for the cpu"); + } + +#define DECLARE_IMPL(T) \ + template void \ + DequantizeAwq::dequantize( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + } +} diff --git a/src/ops/awq/dequantize_gpu.cu b/src/ops/awq/dequantize_gpu.cu new file mode 100644 index 000000000..f63361035 --- /dev/null +++ b/src/ops/awq/dequantize_gpu.cu @@ -0,0 +1,115 @@ +#include +#include "dequantize.cuh" +#include "cuda/helpers.h" + +namespace ctranslate2 { + namespace ops { + + __global__ void __launch_bounds__(64) dequantize_weights(const int* __restrict__ B, // 4096x64 4096 rows 64 cols + const half * __restrict__ scaling_factors, // 32x512 32 rows 512 cols + const int* __restrict__ zeros, // 32x64 32 rows 64 cols + half * __restrict__ C, // 4096x512 4096 rows 512 cols + int G, + int in_c, + int out_c) + { + if (blockIdx.z > 0) { + B = B + blockIdx.z * in_c * out_c / 8; + scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G; + zeros = zeros + blockIdx.z * in_c * out_c / G / 8; + C = C + blockIdx.z * in_c * out_c; + } + static constexpr uint32_t ZERO = 0x0; + half B_shared[32 * (128 + 8)]; + + half* B_shared_ptr2 = B_shared; + + int N = blockDim.x * gridDim.x; // 2 + int col = (blockIdx.x * blockDim.x + threadIdx.x); + int row = blockIdx.y * blockDim.y + threadIdx.y; + int index1 = 8 * col + 8 * row * N; // + i (<8) + half* C_ptr2 = C + index1; + + int index2 = col + row * N; + const int* B_ptr2 = B + index2; + + int index3 = col + (int)(row / G) * N; + const int* zeros_ptr2 = zeros + index3; + int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8) + const half* scaling_factors_ptr2 = scaling_factors + index4; + + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); + int j=0; + + uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); 
+ uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; + + for (int i=0; i<8; ++i) { + *(C_ptr2 + i) = B_shared[i]; + } + } + + template + void DequantizeAwq::dequantize(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const { + dim_t in_c = input.rank() == 2 ? input.dim(0) : input.dim(1); + dim_t qout_c = input.rank() == 2 ? input.dim(1) : input.dim(2); + int num_experts = input.rank() == 2 ? 1 : input.dim(0); + int out_c = qout_c * 8; + int G = in_c / (input.rank() == 2 ? scale.dim(0) : scale.dim(1)); + + int x_thread = 0 /*thx*/; + int y_thread = 0 /*thy*/; + + int x_blocks = 1; + int y_blocks = 1; + x_thread = qout_c; + y_thread = in_c; + + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + if (num_experts == 1) { + output.resize({in_c, out_c}); + } else { + output.resize({num_experts, in_c, out_c}); + } + + auto output_data = reinterpret_cast(output.data()); + const auto scale_data = reinterpret_cast(scale.data()); + + dim3 num_blocks(x_blocks, y_blocks, num_experts); + dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 + + dequantize_weights<<>>(input.data(), scale_data, + zero.data(), output_data, G, in_c, out_c); + } + +#define DECLARE_IMPL(T) \ + template void \ + DequantizeAwq::dequantize( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + + } +} diff --git a/src/ops/awq/gemm.cc b/src/ops/awq/gemm.cc new file mode 100644 index 000000000..bb3834213 --- /dev/null +++ b/src/ops/awq/gemm.cc @@ -0,0 +1,34 @@ +#include +#include +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + void GemmAwq::operator()(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c, + const StorageView* bias) const { + PROFILE("Gemm Awq"); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + throw std::runtime_error("AWQ Gemm does not support for cuda arch < 7.5"); +#else + if (a.dtype() != DataType::FLOAT16 && b.dtype() != DataType::INT32) + throw std::invalid_argument("Awq gemm is only supported for float16 input and int32 weight"); + if (a.device() == Device::CPU) + throw std::invalid_argument("Awq gemm is only supported on GPU"); + + DEVICE_DISPATCH(a.device(), (compute(a, b, scale, zero, c))); + + StorageView tmp(c.dtype(), c.device()); + ops::Sum(0)(c, tmp); + tmp.squeeze(0); + c = std::move(tmp); + + apply_bias_and_activation(c, 
bias, _activation_type); +#endif + } + } +} diff --git a/src/ops/awq/gemm_cpu.cc b/src/ops/awq/gemm_cpu.cc new file mode 100644 index 000000000..010385659 --- /dev/null +++ b/src/ops/awq/gemm_cpu.cc @@ -0,0 +1,25 @@ +#include + +namespace ctranslate2 { + namespace ops { + template + void GemmAwq::compute(const StorageView&, + const StorageView&, + const StorageView&, + const StorageView&, + StorageView&) const { + throw std::runtime_error("AWQ gemm is not applied for the cpu"); + } + +#define DECLARE_IMPL(T) \ + template void \ + GemmAwq::compute( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + } +} diff --git a/src/ops/awq/gemm_gpu.cu b/src/ops/awq/gemm_gpu.cu new file mode 100644 index 000000000..544ce8559 --- /dev/null +++ b/src/ops/awq/gemm_gpu.cu @@ -0,0 +1,543 @@ +#include "cuda/utils.h" +#include "dequantize.cuh" +#include +#include + +namespace ctranslate2 { + namespace ops { + __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ scaling_factors, const int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (128 + 8)]; + + int j_factors1 = ((OC + 128 - 1) / 128); + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[32]; + for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / 128; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + const half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + const int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * 2 + + (((int)threadIdx.x) / (128 / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (128 / 8) + + (((int)threadIdx.x) % (128 / 8)) * 1; +// Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + const int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (128 / 8) + + ((int)threadIdx.x) % (128 / 8); + + const half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (128) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 128 + + ((int)threadIdx.y) * 64 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + const int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float 
*)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#else + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#endif + } + } + } + +// TODO: Shang: Hoist loop invariance. 
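+    // Write back the fp32 accumulators as fp16 into the split-k slice of C selected
+    // by blockIdx_z (C_ptr already includes the blockIdx_z * M * OC offset); the
+    // split_k_iters partial results are then reduced by ops::Sum(0) in
+    // src/ops/awq/gemm.cc above.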
+ for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +#endif + } + + + __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ const scaling_factors, const int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (64 + 8)]; + + __shared__ half scaling_factors_shared[64]; + __shared__ half zeros_shared[64]; + + int j_factors1 = ((OC + 64 - 1) / 64); + + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[16]; + for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / 64; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + const half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + const int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * 4 + + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + (((int)threadIdx.x) % (64 / 8)) * 1; +// Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) + + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + const int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + ((int)threadIdx.x) % (64 / 8); + + const half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (64) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 64 + + ((int)threadIdx.y) * 32 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. 
and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + const int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
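+ // Dequantize two packed fp16 weights per instruction: sub.f16x2 subtracts the packed zero-points,
+ // then fma.rn.f16x2 multiplies by the packed scales and adds 0, i.e. deq = (q - zero) * scale
+ // for each half2 lane of the uint4.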
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) + { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), 
"=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#else + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } +#endif + } + } + } + +// TODO: Shang: Hoist loop invariance. 
+ for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +#endif + } + + template + void GemmAwq::compute(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + dim_t num_in_channels = a.dim(-1); + dim_t num_in_feats = a.size() / num_in_channels; + dim_t split_k_iters = 8; + + if (a.rank() == 2) + c.resize({split_k_iters, num_in_feats, b.dim(1) * 8}); + else if (a.rank() == 3) + c.resize({split_k_iters, a.dim(0), a.dim(1), b.dim(1) * 8}); + + dim_t num_out_feats = num_in_feats; + dim_t num_out_channels = c.dim(-1); + + const auto a_data = reinterpret_cast(a.data()); + const auto b_data = reinterpret_cast(b.data()); + auto output_data = reinterpret_cast(c.data()); + const auto scale_data = reinterpret_cast(scale.data()); + const auto zero_data = reinterpret_cast(zero.data()); + dim_t group_size = num_in_channels / scale.dim(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + if (num_out_channels % 128 == 0) + { + dim_t j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n128k32<<>>( + group_size, split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n64k32<<>>( + group_size, split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data); + } +#endif + } + + +#define DECLARE_IMPL(T) \ + template void \ + GemmAwq::compute( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + } +} \ No newline at end of file diff --git a/src/ops/awq/gemv.cc b/src/ops/awq/gemv.cc new file mode 100644 index 000000000..0245ac1af --- /dev/null +++ b/src/ops/awq/gemv.cc @@ -0,0 +1,39 @@ +#include +#include +#include + +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + void GemvAwq::operator()(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c, + const StorageView* bias) const { + PROFILE("Gemv Awq"); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + throw std::runtime_error("AWQ Gemv does not support for cuda arch < 7.5"); +#else + if (a.dtype() != DataType::FLOAT16 && b.dtype() != DataType::INT32) 
+ throw std::invalid_argument("Awq gemm is only supported for float16 input and int32 weight"); + if (a.device() == Device::CPU) + throw std::invalid_argument("Awq gemm is only supported on GPU"); + + if (a.dim(0) > 8) { + DEVICE_DISPATCH(a.device(), (compute_gemv2(a, b, scale, zero, c))); + StorageView tmp(c.dtype(), c.device()); + ops::Sum(0)(c, tmp); + tmp.squeeze(0); + c = std::move(tmp); + } + else + DEVICE_DISPATCH(a.device(), (compute_gemv(a, b, scale, zero, c))); + + apply_bias_and_activation(c, bias, _activation_type); +#endif + } + } +} diff --git a/src/ops/awq/gemv_cpu.cc b/src/ops/awq/gemv_cpu.cc new file mode 100644 index 000000000..77603da11 --- /dev/null +++ b/src/ops/awq/gemv_cpu.cc @@ -0,0 +1,40 @@ +#include + +namespace ctranslate2 { + namespace ops { + template + void GemvAwq::compute_gemv(const StorageView&, + const StorageView&, + const StorageView&, + const StorageView&, + StorageView&) const { + throw std::runtime_error("AWQ gemv is not applied for the cpu"); + } + template + void GemvAwq::compute_gemv2(const StorageView&, + const StorageView&, + const StorageView&, + const StorageView&, + StorageView&) const { + throw std::runtime_error("AWQ gemv2 is not applied for the cpu"); + } + +#define DECLARE_IMPL(T) \ + template void \ + GemvAwq::compute_gemv2( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; \ + template void \ + GemvAwq::compute_gemv( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + } +} diff --git a/src/ops/awq/gemv_gpu.cu b/src/ops/awq/gemv_gpu.cu new file mode 100644 index 000000000..a24b92773 --- /dev/null +++ b/src/ops/awq/gemv_gpu.cu @@ -0,0 +1,548 @@ +#include "cuda/utils.h" +#include "dequantize.cuh" +#include +#include +#define PACK_FACTOR 8 +#define WARP_SIZE 32 + +namespace ctranslate2 { + namespace ops { + template + __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ scaling_factors, const int* zeros, int M, int IC, int OC, half* __restrict__ C) + { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + static constexpr uint32_t ZERO = 0x0; + float C_warp[64]; + __shared__ half A_shared[128 * (32 + 8)]; + __shared__ half B_shared[64 * (32 + 8)]; + + // __shared__ half scaling_factors_shared[64]; + // __shared__ half zeros_shared[64]; + + int j_factors1 = ((OC + 64 - 1) / 64); + + //int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1); + + half A_shared_warp[32]; + half B_shared_warp[16]; + for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) { + for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0; + } + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride_A = 4 * 32 * 8 / 32; + static constexpr int row_stride = 4 * 32 * 8 / 32; + const int make_divisible_multipler = 128 / G; + const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multipler) * make_divisible_multipler; + const int sf_w = zeros_w * 8; + + int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < 
M; + + const half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + const int* B_ptr = B + + ((int)threadIdx.y) * (IC / 8) * 8 + + (((int)threadIdx.x) / (32 / 8)) * (IC / 8) + + (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8) + + (((int)threadIdx.x) % (32 / 8)) * 1; + +// Why * 1 in the above line? + + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 4) * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8)) * 8; + + + const int* zeros_ptr = zeros + + ((int)threadIdx.y) * zeros_w * 8 + + (((int)threadIdx.x) / (32 / 8)) * zeros_w + + (((int)blockIdx_y) % j_factors1) * 64 * zeros_w + // this term is zero + + (((int)threadIdx.x) % (32 / 8)) / G ; + + const half* scaling_factors_ptr = scaling_factors + + ((int)threadIdx.y) * sf_w * 8 + + (((int)threadIdx.x) / (32 / 8)) * sf_w + + (((int)blockIdx_y) % j_factors1) * (64) * sf_w + // this term is zero + + (((int)threadIdx.x) % (32 / 8)) * 8 / G; + + + // Haotian: TBD, check, May 29 11:46 AM PST + half* C_ptr = C + + static_cast(blockIdx_z) * M * OC // blockIdx_z -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 64 + + (((int)threadIdx.y) / 2) * 32 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1; + + // TODO (Haotian): load scales and zero points to smem + + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: Here we assume M % cta_M = 0. 
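+ // Stage the A tile into shared memory in 128-bit (8 x fp16) chunks; rows at or beyond M
+ // are zero-filled so the tail block never reads out of bounds.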
+ for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) + { + if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M) + { + *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0); + } + } + + + const int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8; + const half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G; + + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + const int* B_ptr_local = B_ptr + k_0_0 * (32 / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8)); + int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w); + zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4; + float current_zeros = (float)(zeros_loaded & 0xF); + half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w); + half B_loaded_fp16[8]; +#pragma unroll + for (int ic_1 = 0; ic_1 < 8; ic_1++){ + float current_single_weight_fp = (float)(B_loaded_current & 0xF); + half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros)); + B_loaded_current = B_loaded_current >> 4; + B_loaded_fp16[ic_1] = dequantized_weight; + } + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast(B_loaded_fp16); + } + __syncthreads(); + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3]) + : "r"(addr) + ); + } + } + + for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) { + { + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3]) + : "r"(addr) + ); + } + } + + for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) { + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, 
%9, %10};\n" + : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])); + } +#else + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 
* 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])); + } + + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])); + } +#endif + } + } + } + } + +// Haotian: Here (May 29 11:46AM PST) +// TODO: Shang: Hoist loop invariance. + for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) { + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4; + if (row_offset < M) + { + *(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]); + } + } + } + } +#endif + } + + // Reduce sum within the warp using the tree reduction algorithm. + __device__ __forceinline__ float warp_reduce_sum(float sum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else +#pragma unroll + for(int i = 4; i >= 0; i--){ + sum += __shfl_down_sync(0xffffffff, sum, 1<(zeros + oc_idx * zeros_w + packed_group_idx * 2); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups. 
+ float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs +#pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); +#pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +#endif + } + + +/* +Computes GEMV (group_size = 128). + +Args: + inputs: vector of shape [batch_size, IC]; + weight: matrix of shape [OC, IC / 8]; + output: vector of shape [OC]; + zeros: matrix of shape [OC, IC / group_size / 8]; + scaling_factors: matrix of shape [OC, IC / group_size]; + +Notes: + One cannot infer group_size from the shape of scaling factors. + the second dimension is rounded up to a multiple of PACK_FACTOR. +*/ + __global__ void gemv_kernel_g128( + const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs, + const int IC, const int OC){ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + const int group_size = 128; + float psum = 0; + const int batch_idx = blockIdx.z; + const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; + const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; + half* outputs = _outputs + batch_idx * OC; + const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); + const int weight_w = IC / PACK_FACTOR; + // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address + const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); + // consistent with input shape + const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); + // tile size: 4 OC x 1024 IC per iter + for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ + // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. 
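+ // One uint32_t packs the 4-bit zero-points of 8 consecutive groups; with group_size = 128,
+ // four lanes share a group, so each lane selects its nibble with a shift of (threadIdx.x / 4) * 4.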
+ uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. + float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs +#pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); +#pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +#endif + } + + template + void GemvAwq::compute_gemv(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + assert(false); +#else + dim_t num_in_channels = a.dim(-1); + dim_t num_in_feats = a.size() / num_in_channels; + if (a.rank() == 2) + c.resize({num_in_feats, b.dim(0)}); + else if (a.rank() == 3) + c.resize({a.dim(0), a.dim(1), b.dim(0)}); + + const auto a_data = reinterpret_cast(a.data()); + const auto b_data = reinterpret_cast(b.data()); + auto output_data = reinterpret_cast(c.data()); + const auto scale_data = reinterpret_cast(scale.data()); + const auto zero_data = reinterpret_cast(zero.data()); + dim_t group_size = num_in_channels / scale.dim(-1); + + dim_t num_out_feats = num_in_feats; + dim_t num_out_channels = c.dim(-1); + dim3 num_blocks(1, num_out_channels / 4, num_out_feats); + dim3 num_threads(32, 4); + if (group_size == 64) + { + gemv_kernel_g64<<>>( + // pointers + a_data, b_data, zero_data, scale_data, output_data, + // constants + num_in_channels, num_out_channels + ); + } + else if (group_size == 128) + { + gemv_kernel_g128<<>>( + // pointers + a_data, b_data, zero_data, scale_data, output_data, + // constants + num_in_channels, num_out_channels + ); + } +#endif + } + + template + void GemvAwq::compute_gemv2(const StorageView& a, + const StorageView& b, + const StorageView& scale, + const StorageView& zero, + StorageView& c) const { +#if defined(__CUDA_ARCH__) && 
__CUDA_ARCH__ < 750 + assert(false); +#else + dim_t num_in_channels = a.dim(-1); + dim_t num_in_feats = a.size() / num_in_channels; + dim_t split_k_iters = 8; + + if (a.rank() == 2) + c.resize({split_k_iters, num_in_feats, b.dim(0)}); + else if (a.rank() == 3) + c.resize({split_k_iters, a.dim(0), a.dim(1), b.dim(0)}); + + dim_t num_out_feats = num_in_feats; + dim_t num_out_channels = c.dim(-1); + + const auto a_data = reinterpret_cast(a.data()); + const auto b_data = reinterpret_cast(b.data()); + auto output_data = reinterpret_cast(c.data()); + const auto scale_data = reinterpret_cast(scale.data()); + const auto zero_data = reinterpret_cast(zero.data()); + dim_t group_size = num_in_channels / scale.dim(-1); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 4); + if (group_size == 128) + { + gemmv2_forward_4bit_cuda_m128n64k32<128><<>>( + split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data); + } + else if (group_size == 64) + { + gemmv2_forward_4bit_cuda_m128n64k32<64><<>>( + split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data); + } + else + { + throw std::invalid_argument("Group size temporarily not supported."); + } +#endif + } + + +#define DECLARE_IMPL(T) \ + template void \ + GemvAwq::compute_gemv2( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; \ + template void \ + GemvAwq::compute_gemv( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; + + DECLARE_IMPL(float16_t) + } +} \ No newline at end of file diff --git a/src/ops/mean.cc b/src/ops/mean.cc index 042f8c169..afac4cd82 100644 --- a/src/ops/mean.cc +++ b/src/ops/mean.cc @@ -39,7 +39,7 @@ namespace ctranslate2 { inner_size *= input.dim(i); DEVICE_AND_FLOAT_DISPATCH("Mean", input.device(), input.dtype(), - (compute(input, outer_size, axis_size, inner_size, output))); + (compute(input, outer_size, axis_size, inner_size, false, output))); } } diff --git a/src/ops/mean_cpu.cc b/src/ops/mean_cpu.cc index be8fd9d90..a4edb12c8 100644 --- a/src/ops/mean_cpu.cc +++ b/src/ops/mean_cpu.cc @@ -11,6 +11,7 @@ namespace ctranslate2 { const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, + const bool get_sum, StorageView& output) const { const auto* src = input.data(); auto* dst = output.data(); @@ -22,7 +23,9 @@ namespace ctranslate2 { for (dim_t k = 0; k < axis_size; ++k) { sum += src[i * axis_size * inner_size + k * inner_size + j]; } - dst[i * inner_size + j] = sum / float(axis_size); + dst[i * inner_size + j] = sum; + if (!get_sum) + dst[i * inner_size + j] /= float(axis_size); } } }); @@ -34,6 +37,7 @@ namespace ctranslate2 { const dim_t outer_size, \ const dim_t axis_size, \ const dim_t inner_size, \ + const bool get_sum, \ StorageView& output) const; DECLARE_IMPL(float) diff --git a/src/ops/mean_gpu.cu b/src/ops/mean_gpu.cu index 5125924c9..a57a679a6 100644 --- a/src/ops/mean_gpu.cu +++ b/src/ops/mean_gpu.cu @@ -15,6 +15,7 @@ namespace ctranslate2 { const cuda::index_t outer_size, 
const cuda::index_t axis_size, const cuda::index_t inner_size, + const bool get_sum, T* output) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -30,7 +31,9 @@ namespace ctranslate2 { AccumT sum = BlockReduce(temp_storage).Sum(thread_sum); if (threadIdx.x == 0) { - output[blockIdx.x] = sum / AccumT(axis_size); + output[blockIdx.x] = sum; + if (!get_sum) + output[blockIdx.x] /= AccumT(axis_size); } } @@ -39,6 +42,7 @@ namespace ctranslate2 { const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, + const bool get_sum, StorageView& output) const { const dim_t blocks = std::min(outer_size * inner_size, cuda::max_blocks); mean_kernel, float><<>>( @@ -46,6 +50,7 @@ namespace ctranslate2 { outer_size, axis_size, inner_size, + get_sum, cuda::device_cast(output.data())); } @@ -55,6 +60,7 @@ namespace ctranslate2 { const dim_t outer_size, \ const dim_t axis_size, \ const dim_t inner_size, \ + const bool get_sum, \ StorageView& output) const; DECLARE_IMPL(float) diff --git a/src/ops/sum.cc b/src/ops/sum.cc new file mode 100644 index 000000000..bbbc5a8bc --- /dev/null +++ b/src/ops/sum.cc @@ -0,0 +1,44 @@ +#include "ctranslate2/ops/sum.h" + +#include "dispatch.h" +namespace ctranslate2 { + namespace ops { + + Sum::Sum(const dim_t axis) + : Mean(axis) + { + } + + void Sum::operator()(const StorageView& input, StorageView& output) const { + PROFILE("Sum"); + + const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis; + if (axis >= input.rank()) + throw std::out_of_range("Cannot compute sum of axis " + std::to_string(axis) + + " for input with rank " + std::to_string(input.rank())); + + const dim_t axis_size = input.dim(axis); + if (axis_size == 1) { + output = input; + return; + } + + { + Shape output_shape(input.shape()); + output_shape[axis] = 1; + output.resize(std::move(output_shape)); + } + + dim_t inner_size = 1; + dim_t outer_size = 1; + for (dim_t i = 0; i < axis; ++i) + outer_size *= input.dim(i); + for (dim_t i = axis + 1; i < input.rank(); ++i) + inner_size *= input.dim(i); + + DEVICE_AND_FLOAT_DISPATCH("Sum", input.device(), input.dtype(), + (compute(input, outer_size, axis_size, inner_size, true, output))); + } + + } +}
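
Note on the GEMV dispatch above: when the leading dimension of the input is greater than 8, GemvAwq::operator() calls compute_gemv2, which writes one partial result per split-k slice into a tensor of shape [split_k_iters, ...], and the new Sum op (Mean run with get_sum = true) then collapses axis 0 before the bias and activation are applied. The standalone C++ sketch below only illustrates that axis-0 reduction on a dense [split_k, M, OC] buffer; the helper name reduce_split_k and the toy sizes are assumptions made for this example, not part of the patch.

#include <cstddef>
#include <iostream>
#include <vector>

// Minimal sketch: reduce a [split_k, M, OC] buffer over axis 0, which is the effect of
// ops::Sum(0) followed by squeeze(0) on the split-k partial GEMV outputs.
// split_k, m and oc are illustrative values, not taken from the patch.
std::vector<float> reduce_split_k(const std::vector<float>& c,
                                  std::size_t split_k,
                                  std::size_t m,
                                  std::size_t oc) {
  std::vector<float> out(m * oc, 0.f);
  for (std::size_t s = 0; s < split_k; ++s)
    for (std::size_t i = 0; i < m * oc; ++i)
      out[i] += c[s * m * oc + i];   // accumulate each slice's partial sums
  return out;
}

int main() {
  const std::size_t split_k = 8, m = 2, oc = 4;   // assumed toy sizes
  std::vector<float> c(split_k * m * oc, 1.f);    // every slice contributes 1
  const auto out = reduce_split_k(c, split_k, m, oc);
  std::cout << out[0] << "\n";                    // prints 8 (= split_k)
  return 0;
}

In the patch itself this reduction is performed by ops::Sum(0) and squeeze(0) directly on the device tensor, so no host copy is involved.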