From 39f48f2e843df52245e6c857326e1115bca12b03 Mon Sep 17 00:00:00 2001
From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com>
Date: Thu, 4 Jul 2024 10:27:40 +0200
Subject: [PATCH] Quantization AWQ GEMM + GEMV (#1727)
* quantization awq gemm + gemv
* fix pipeline
* fix pipeline
* fix pipeline
* fix dequantize awq
* remove duplicated code
---
CMakeLists.txt | 10 +
README.md | 2 +-
docs/quantization.md | 19 +
include/ctranslate2/layers/common.h | 3 +
include/ctranslate2/models/model.h | 17 +-
include/ctranslate2/ops/awq/dequantize_awq.h | 26 +
include/ctranslate2/ops/awq/gemm.h | 27 +
include/ctranslate2/ops/awq/gemv.h | 33 ++
include/ctranslate2/ops/gemm.h | 3 +-
include/ctranslate2/ops/mean.h | 3 +-
include/ctranslate2/ops/ops.h | 4 +
include/ctranslate2/ops/sum.h | 17 +
python/ctranslate2/converters/transformers.py | 150 ++++-
python/ctranslate2/converters/utils.py | 19 +
python/ctranslate2/specs/common_spec.py | 9 +
python/ctranslate2/specs/transformer_spec.py | 34 +-
src/layers/common.cc | 35 ++
src/models/model.cc | 100 ++--
src/ops/awq/dequantize.cc | 24 +
src/ops/awq/dequantize.cuh | 79 +++
src/ops/awq/dequantize_cpu.cc | 23 +
src/ops/awq/dequantize_gpu.cu | 115 ++++
src/ops/awq/gemm.cc | 34 ++
src/ops/awq/gemm_cpu.cc | 25 +
src/ops/awq/gemm_gpu.cu | 543 +++++++++++++++++
src/ops/awq/gemv.cc | 39 ++
src/ops/awq/gemv_cpu.cc | 40 ++
src/ops/awq/gemv_gpu.cu | 548 ++++++++++++++++++
src/ops/mean.cc | 2 +-
src/ops/mean_cpu.cc | 6 +-
src/ops/mean_gpu.cu | 8 +-
src/ops/sum.cc | 44 ++
32 files changed, 1964 insertions(+), 77 deletions(-)
create mode 100644 include/ctranslate2/ops/awq/dequantize_awq.h
create mode 100644 include/ctranslate2/ops/awq/gemm.h
create mode 100644 include/ctranslate2/ops/awq/gemv.h
create mode 100644 include/ctranslate2/ops/sum.h
create mode 100644 src/ops/awq/dequantize.cc
create mode 100644 src/ops/awq/dequantize.cuh
create mode 100644 src/ops/awq/dequantize_cpu.cc
create mode 100644 src/ops/awq/dequantize_gpu.cu
create mode 100644 src/ops/awq/gemm.cc
create mode 100644 src/ops/awq/gemm_cpu.cc
create mode 100644 src/ops/awq/gemm_gpu.cu
create mode 100644 src/ops/awq/gemv.cc
create mode 100644 src/ops/awq/gemv_cpu.cc
create mode 100644 src/ops/awq/gemv_gpu.cu
create mode 100644 src/ops/sum.cc
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18eeb826c..ac94aac57 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -191,6 +191,13 @@ set(SOURCES
src/ops/transpose.cc
src/ops/nccl_ops.cc
src/ops/nccl_ops_cpu.cc
+ src/ops/awq/dequantize.cc
+ src/ops/awq/dequantize_cpu.cc
+ src/ops/awq/gemm.cc
+ src/ops/awq/gemm_cpu.cc
+ src/ops/awq/gemv.cc
+ src/ops/awq/gemv_cpu.cc
+ src/ops/sum.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -595,6 +602,9 @@ if (WITH_CUDA)
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
+ src/ops/awq/gemm_gpu.cu
+ src/ops/awq/gemv_gpu.cu
+ src/ops/awq/dequantize_gpu.cu
)
set_source_files_properties(
diff --git a/README.md b/README.md
index 7ce65486b..bfb64c851 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ The project is production-oriented and comes with [backward compatibility guaran
## Key features
* **Fast and efficient execution on CPU and GPU**
The execution [is significantly faster and requires less resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc.
-* **Quantization and reduced precision**
The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), and 8-bit integers (INT8).
+* **Quantization and reduced precision**
The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), 8-bit integers (INT8) and AWQ quantization (INT4).
* **Multiple CPU architectures support**
The project supports x86-64 and AArch64/ARM64 processors and integrates multiple backends that are optimized for these platforms: [Intel MKL](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html), [oneDNN](https://github.com/oneapi-src/oneDNN), [OpenBLAS](https://www.openblas.net/), [Ruy](https://github.com/google/ruy), and [Apple Accelerate](https://developer.apple.com/documentation/accelerate).
* **Automatic CPU detection and code dispatch**
One binary can include multiple backends (e.g. Intel MKL and oneDNN) and instruction set architectures (e.g. AVX, AVX2) that are automatically selected at runtime based on the CPU information.
* **Parallel and asynchronous execution**
Multiple batches can be processed in parallel and asynchronously using multiple GPUs or CPU cores.
diff --git a/docs/quantization.md b/docs/quantization.md
index aa1d247d4..296c57000 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -6,6 +6,7 @@ Quantization is a technique that can reduce the model size and accelerate its ex
* 16-bit integers (INT16)
* 16-bit floating points (FP16)
* 16-bit brain floating points (BF16)
+* 4-bit AWQ Quantization
```{tip}
See the benchmark results in the main [README](https://github.com/OpenNMT/CTranslate2#benchmarks) to compare the performance and memory usage with and without quantization.
@@ -161,3 +162,21 @@ In this mode, all model weights are stored in half precision and all layers are
* NVIDIA GPU with Compute Capability >= 8.0
In this mode, all model weights are stored in BF16 and all layers are run with this type.
+
+### 4-bit AWQ
+
+The compute type is `int32_float16`.
+
+**Supported on:**
+
+* NVIDIA GPU with Compute Capability >= 7.5
+
+In this mode, the quantized weights are stored as packed 4-bit integers and all layers are run in half precision. The quantization scales are stored in ``float16`` and the zero points are packed in ``int32``.
+
+For example, to convert an AWQ-quantized Llama 2 model:
+
+```bash
+ ct2-transformers-converter --model TheBloke/Llama-2-7B-AWQ --copy_files tokenizer.model --output_dir ct2_model
+```
+
+The model has to be quantized with AWQ first, and only then converted to the CTranslate2 format.
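+
+Once converted, the model can be loaded like any regular CTranslate2 model. The snippet below is only a minimal sketch: it assumes the converted model is in `ct2_model`, that the original Hugging Face tokenizer is reused, and that a CUDA device is available; adjust paths and generation options as needed.
+
```python
import ctranslate2
import transformers

# Load the converted AWQ model on GPU (the AWQ kernels are only implemented for CUDA).
generator = ctranslate2.Generator("ct2_model", device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")

# Tokenize the prompt into string tokens, as expected by generate_batch.
prompt = "What is 4-bit quantization?"
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

results = generator.generate_batch([tokens], max_length=64, sampling_topk=10)
output_ids = tokenizer.convert_tokens_to_ids(results[0].sequences[0])
print(tokenizer.decode(output_ids))
```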
\ No newline at end of file
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index 7ea5e9126..3985b3feb 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -138,16 +138,19 @@ namespace ctranslate2 {
const StorageView& _weight;
const StorageView* _bias;
const StorageView* _qscale;
+ const StorageView* _qzero;
const StorageView* _u8_shift_compensation;
StorageView _partial_weight;
StorageView _partial_bias;
StorageView _partial_qscale;
StorageView _partial_u8_shift_compensation;
const DataType _output_type;
+ const models::QUANTIZATION_TYPE _quant_method;
const bool _quantized_gemm;
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
+ const ops::ActivationType* _activation_type;
const bool _is_layer_out;
};
diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h
index babba1455..32e4f8403 100644
--- a/include/ctranslate2/models/model.h
+++ b/include/ctranslate2/models/model.h
@@ -11,6 +11,12 @@
namespace ctranslate2 {
namespace models {
+ enum class QUANTIZATION_TYPE {
+ CT2,
+ AWQ_GEMM,
+ AWQ_GEMV
+ };
+
static const size_t current_binary_version = 6;
// Checks whether the provided path could contain a CTranslate2 model.
@@ -90,6 +96,14 @@ namespace ctranslate2 {
return _use_flash_attention;
}
+ QUANTIZATION_TYPE quant_method() const {
+ return _quant_method;
+ }
+
+ void set_quant_method(QUANTIZATION_TYPE type) {
+ _quant_method = type;
+ }
+
virtual bool use_global_int16_scale() const {
return true;
}
@@ -160,7 +174,7 @@ namespace ctranslate2 {
private:
void process_linear_weights();
- void set_compute_type(ComputeType type, Device device, int device_index);
+ void set_compute_type(ComputeType type, Device device, int device_index, bool update_weight=true);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
@@ -177,6 +191,7 @@ namespace ctranslate2 {
    std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
bool _use_flash_attention = false;
bool _tensor_parallel = false;
+ QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2;
};
template<>
diff --git a/include/ctranslate2/ops/awq/dequantize_awq.h b/include/ctranslate2/ops/awq/dequantize_awq.h
new file mode 100644
index 000000000..3b1a20f6a
--- /dev/null
+++ b/include/ctranslate2/ops/awq/dequantize_awq.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "../op.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ class DequantizeAwq : public Op {
+ public:
+ DequantizeAwq();
+
+ void operator()(const StorageView& input,
+ const StorageView& scale,
+ const StorageView& zeros,
+ StorageView& output) const;
+
+ private:
+      template <Device D, typename T>
+ void dequantize(const StorageView& input,
+ const StorageView& scale,
+ const StorageView& zeros,
+ StorageView& output) const;
+ };
+
+ }
+}
diff --git a/include/ctranslate2/ops/awq/gemm.h b/include/ctranslate2/ops/awq/gemm.h
new file mode 100644
index 000000000..8233f12f7
--- /dev/null
+++ b/include/ctranslate2/ops/awq/gemm.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include "../activation.h"
+#include "../gemm.h"
+
+namespace ctranslate2 {
+ namespace ops {
+ class GemmAwq : public Gemm {
+ public:
+ using Gemm::Gemm;
+ void operator()(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c,
+ const StorageView* bias = nullptr) const;
+
+ private:
+      template <Device D, typename T>
+ void compute(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const;
+ };
+ }
+}
\ No newline at end of file
diff --git a/include/ctranslate2/ops/awq/gemv.h b/include/ctranslate2/ops/awq/gemv.h
new file mode 100644
index 000000000..62866e515
--- /dev/null
+++ b/include/ctranslate2/ops/awq/gemv.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "../activation.h"
+#include "../gemm.h"
+
+namespace ctranslate2 {
+ namespace ops {
+ class GemvAwq : public Gemm {
+ public:
+ using Gemm::Gemm;
+ void operator()(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c,
+ const StorageView* bias = nullptr) const;
+
+ private:
+      template <Device D, typename T>
+ void compute_gemv(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const;
+      template <Device D, typename T>
+ void compute_gemv2(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const;
+ };
+ }
+}
\ No newline at end of file
diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h
index c309063d6..3c4efbb02 100644
--- a/include/ctranslate2/ops/gemm.h
+++ b/include/ctranslate2/ops/gemm.h
@@ -39,6 +39,8 @@ namespace ctranslate2 {
const dim_t k,
const dim_t n,
const float alpha);
+ protected:
+ const ActivationType* _activation_type;
private:
float _alpha;
@@ -47,7 +49,6 @@ namespace ctranslate2 {
bool _trans_b;
bool _a_is_packed;
bool _b_is_packed;
- const ActivationType* _activation_type;
template
void compute(const StorageView& a,
diff --git a/include/ctranslate2/ops/mean.h b/include/ctranslate2/ops/mean.h
index 251f342dc..501394c3b 100644
--- a/include/ctranslate2/ops/mean.h
+++ b/include/ctranslate2/ops/mean.h
@@ -11,12 +11,13 @@ namespace ctranslate2 {
void operator()(const StorageView& input, StorageView& output) const override;
- private:
+ protected:
template
void compute(const StorageView& input,
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
+ const bool get_sum,
StorageView& output) const;
const dim_t _axis;
diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h
index ceca49450..ed9db4265 100644
--- a/include/ctranslate2/ops/ops.h
+++ b/include/ctranslate2/ops/ops.h
@@ -39,3 +39,7 @@
#include "slide.h"
#include "nccl_ops.h"
#include "flash_attention.h"
+#include "awq/gemm.h"
+#include "awq/gemv.h"
+#include "awq/dequantize_awq.h"
+#include "sum.h"
diff --git a/include/ctranslate2/ops/sum.h b/include/ctranslate2/ops/sum.h
new file mode 100644
index 000000000..3a240d8fc
--- /dev/null
+++ b/include/ctranslate2/ops/sum.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include "op.h"
+#include "mean.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ class Sum : public Mean {
+ public:
+ Sum(const dim_t axis);
+
+ void operator()(const StorageView& input, StorageView& output) const override;
+ };
+
+ }
+}
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 719983a3d..849123a50 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -43,6 +43,11 @@
"su": attention_spec.RotaryScalingType.Su,
}
+_SUPPORTED_QUANTIZATION = {
+ "gemm": common_spec.Quantization.AWQ_GEMM,
+ "gemv": common_spec.Quantization.AWQ_GEMV,
+}
+
_MODEL_LOADERS = {}
@@ -217,8 +222,14 @@ def set_layer_norm(self, spec, module):
spec.gamma = module.weight
spec.beta = module.bias
- def set_linear(self, spec, module):
- spec.weight = module.weight
+ def set_linear(self, spec, module, quant_type=common_spec.Quantization.CT2):
+ if quant_type == common_spec.Quantization.CT2:
+ spec.weight = module.weight
+ else:
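+            # Pre-quantized AWQ modules expose packed integer weights together with
+            # their quantization scales and zero points.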
+ spec.weight = module.qweight
+ spec.weight_scale = module.scales
+ spec.weight_zero = module.qzeros
+
if isinstance(module, transformers.Conv1D):
spec.weight = spec.weight.transpose(0, 1)
if module.bias is not None:
@@ -1407,6 +1418,26 @@ def get_model_spec(self, model):
rotary_scaling_type = None
rotary_scaling_factor = 1
+ quantization_config = getattr(model.config, "quantization_config", None)
+ if quantization_config:
+ if quantization_config.quant_method == "awq":
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+ if quant_type is None:
+ raise NotImplementedError(
+ "Quantization type '%s' is not yet implemented. "
+ "The following Quantization types are currently supported: %s"
+ % (
+                        quantization_config.version,
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+ )
+ )
+ quant_group_size = quantization_config.group_size
+ quant_bits = quantization_config.bits
+ else:
+ quant_type = common_spec.Quantization.CT2
+ quant_group_size = None
+ quant_bits = None
+
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
@@ -1420,9 +1451,12 @@ def get_model_spec(self, model):
rotary_scaling_factor=rotary_scaling_factor,
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
+ quant_type=quant_type,
+ quant_group_size=quant_group_size,
+ quant_bits=quant_bits,
)
- self.set_decoder(spec.decoder, model.model)
+ self.set_decoder(spec.decoder, model.model, quant_type)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec
@@ -1451,7 +1485,7 @@ def set_config(self, config, model, tokenizer):
def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
- def set_decoder(self, spec, module):
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)
@@ -1464,17 +1498,39 @@ def set_decoder(self, spec, module):
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
)
- wq = layer.self_attn.q_proj.weight
- wk = layer.self_attn.k_proj.weight
- wv = layer.self_attn.v_proj.weight
- wo = layer.self_attn.o_proj.weight
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
+ self.set_linear(
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+ )
- layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
- layer_spec.self_attention.linear[1].weight = wo
+ if quant_type == common_spec.Quantization.CT2:
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+ else:
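+                # The packed weight layout differs between the AWQ GEMM and GEMV kernels,
+                # so the fused QKV projection is concatenated along a different axis.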
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+ utils.fuse_linear_prequant(
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
+ )
+ self.set_linear(
+ layer_spec.self_attention.linear[1],
+ layer.self_attn.o_proj,
+ quant_type=quant_type,
+ )
- self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
- self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
- self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+ self.set_linear(
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+ )
delattr(layer, "self_attn")
delattr(layer, "mlp")
@@ -1512,6 +1568,26 @@ def get_model_spec(self, model):
rotary_scaling_type = None
rotary_scaling_factor = 1
+ quantization_config = getattr(model.config, "quantization_config", None)
+ if quantization_config:
+ if quantization_config.quant_method == "awq":
+ quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+ if quant_type is None:
+ raise NotImplementedError(
+ "Quantization type '%s' is not yet implemented. "
+ "The following Quantization types are currently supported: %s"
+ % (
+                        quantization_config.version,
+ ", ".join(_SUPPORTED_QUANTIZATION.keys()),
+ )
+ )
+ quant_group_size = quantization_config.group_size
+ quant_bits = quantization_config.bits
+ else:
+ quant_type = common_spec.Quantization.CT2
+ quant_group_size = None
+ quant_bits = None
+
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
@@ -1526,9 +1602,12 @@ def get_model_spec(self, model):
rotary_base=getattr(model.config, "rope_theta", 10000),
num_heads_kv=num_heads_kv,
sliding_window=sliding_window,
+ quant_type=quant_type,
+ quant_group_size=quant_group_size,
+ quant_bits=quant_bits,
)
- self.set_decoder(spec.decoder, model.model)
+ self.set_decoder(spec.decoder, model.model, quant_type=quant_type)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec
@@ -1553,7 +1632,7 @@ def set_config(self, config, model, tokenizer):
def set_layer_norm(self, spec, layer_norm):
spec.gamma = layer_norm.weight
- def set_decoder(self, spec, module):
+ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.embed_tokens)
self.set_layer_norm(spec.layer_norm, module.norm)
@@ -1565,18 +1644,39 @@ def set_decoder(self, spec, module):
self.set_layer_norm(
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
)
+ split_layers = [common_spec.LinearSpec() for _ in range(3)]
+ self.set_linear(
+ split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+ )
- wq = layer.self_attn.q_proj.weight
- wk = layer.self_attn.k_proj.weight
- wv = layer.self_attn.v_proj.weight
- wo = layer.self_attn.o_proj.weight
-
- layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
- layer_spec.self_attention.linear[1].weight = wo
+ if quant_type == common_spec.Quantization.CT2:
+ utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+ else:
+ cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+ utils.fuse_linear_prequant(
+ layer_spec.self_attention.linear[0], split_layers, cc_dim
+ )
+ self.set_linear(
+ layer_spec.self_attention.linear[1],
+ layer.self_attn.o_proj,
+ quant_type=quant_type,
+ )
- self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
- self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
- self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+ self.set_linear(
+ layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+ )
+ self.set_linear(
+ layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+ )
delattr(layer, "self_attn")
delattr(layer, "mlp")
diff --git a/python/ctranslate2/converters/utils.py b/python/ctranslate2/converters/utils.py
index 48d9d27a1..8ce9e937b 100644
--- a/python/ctranslate2/converters/utils.py
+++ b/python/ctranslate2/converters/utils.py
@@ -33,6 +33,25 @@ def fuse_linear(spec, layers):
)
+def fuse_linear_prequant(spec, layers, axis):
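+    """Concatenates the packed weights, scales and zero points of pre-quantized
+    linear layers along the given axis."""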
+ if not layers:
+ raise ValueError("Cannot fuse linear layers: at least one layer is required")
+ params = ["weight", "weight_scale", "weight_zero"]
+ if isinstance(layers[0].weight, np.ndarray):
+ concatenate = np.concatenate
+ else:
+ import torch
+
+ concatenate = torch.cat
+
+ for param in params:
+ setattr(
+ spec,
+ param,
+ concatenate([getattr(layer, param) for layer in layers], axis=axis),
+ )
+
+
def permute_for_sliced_rotary(weight, num_heads, rotary_dim=None):
"""Permutes the weight to use the sliced rotary implementation."""
if rotary_dim is not None:
diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py
index b517ef77c..b1162839c 100644
--- a/python/ctranslate2/specs/common_spec.py
+++ b/python/ctranslate2/specs/common_spec.py
@@ -23,6 +23,14 @@ class EmbeddingsMerge(enum.IntEnum):
ADD = 1
+class Quantization(enum.IntEnum):
+    """Quantization type."""
+
+ CT2 = 0
+ AWQ_GEMM = 1
+ AWQ_GEMV = 2
+
+
class LayerNormSpec(model_spec.LayerSpec):
def __init__(self, rms_norm=False):
self.gamma = None
@@ -36,6 +44,7 @@ class LinearSpec(model_spec.LayerSpec):
def __init__(self):
self.weight = None
self.weight_scale = model_spec.OPTIONAL
+ self.weight_zero = model_spec.OPTIONAL
self.bias = model_spec.OPTIONAL
def has_bias(self):
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index 2325f7bbf..abb812c8b 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -105,6 +105,9 @@ def __init__(
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
+ quant_type: Optional[common_spec.Quantization] = None,
+ quant_group_size: Optional[int] = None,
+ quant_bits: Optional[int] = None,
):
"""Initializes a Transformer decoder specification.
@@ -147,7 +150,12 @@ def __init__(
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
sliding_window: Max sequence length to retain in KV Cache.
+            quant_type: Quantization method used for low-bit quantization (e.g. AWQ).
+            quant_group_size: Group size used by the low-bit quantization.
+            quant_bits: Number of bits used by the quantization (e.g. 4-bit).
"""
+
+ self._config = dict()
if parallel_residual:
if not pre_norm:
raise ValueError("The GPT-J block expects a pre-norm architecture")
@@ -215,7 +223,7 @@ def __init__(
for _ in range(num_layers)
]
self.start_from_zero_embedding = False
- self.multi_query_attention = multi_query_attention or (
+ self._config["multi_query_attention"] = multi_query_attention or (
num_heads_kv != num_heads
)
@@ -223,6 +231,15 @@ def __init__(
self.project_in = common_spec.LinearSpec()
self.project_out = common_spec.LinearSpec()
+ if quant_type is not None:
+ self._config["quantization_type"] = quant_type
+ self._config["quantization_bits"] = quant_bits
+ self._config["quantization_group_size"] = quant_group_size
+
+ @property
+ def config(self):
+ return self._config
+
class TransformerEncoderLayerSpec(model_spec.LayerSpec):
def __init__(
@@ -485,9 +502,8 @@ def __init__(self, decoder: TransformerDecoderSpec):
super().__init__()
self.decoder = decoder
- self._config.add_attribute(
- "multi_query_attention", self.decoder.multi_query_attention
- )
+ for key, value in self.decoder.config.items():
+ self._config.add_attribute(key, value)
@classmethod
def from_config(
@@ -518,6 +534,9 @@ def from_config(
num_heads_kv: Optional[int] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
+ quant_type: Optional[common_spec.Quantization] = None,
+ quant_group_size: Optional[int] = None,
+ quant_bits: Optional[int] = None,
):
"""Creates a Transformer decoder model specification.
@@ -553,7 +572,11 @@ def from_config(
attention layer norms.
multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
num_heads_kv: Number of attention heads for the key and value.
+          head_dim: Dimension of each attention head.
sliding_window: max sequence length to retain KV cache
+          quant_type: Quantization method used for low-bit quantization (e.g. AWQ).
+          quant_group_size: Group size used by the low-bit quantization.
+          quant_bits: Number of bits used by the quantization (e.g. 4-bit).
"""
decoder = TransformerDecoderSpec(
num_layers,
@@ -583,6 +606,9 @@ def from_config(
num_heads_kv=num_heads_kv,
head_dim=head_dim,
sliding_window=sliding_window,
+ quant_type=quant_type,
+ quant_group_size=quant_group_size,
+ quant_bits=quant_bits,
)
return cls(decoder)
diff --git a/src/layers/common.cc b/src/layers/common.cc
index 6f56c01ef..86fb66a7d 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -271,6 +271,7 @@ namespace ctranslate2 {
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale"))
+ , _qzero(model.get_variable_if_exists(scope + "/weight_zero"))
, _u8_shift_compensation((_weight.device() == Device::CPU
&& _weight.dtype() == DataType::INT8
&& cpu::prefer_u8s8s32_gemm())
@@ -281,6 +282,7 @@ namespace ctranslate2 {
, _partial_qscale(_weight.device(), DataType::FLOAT32)
, _partial_u8_shift_compensation(_weight.device(), DataType::INT32)
, _output_type(get_default_float_type(model.effective_compute_type()))
+ , _quant_method(model.quant_method())
, _quantized_gemm(_weight.dtype() == DataType::INT16 || _weight.dtype() == DataType::INT8)
, _gemm_op(/*alpha=*/1,
/*beta=*/0,
@@ -295,6 +297,7 @@ namespace ctranslate2 {
/*shift_to_uint8=*/bool(_u8_shift_compensation),
/*round_before_cast=*/model.round_before_cast_in_quantization())
, _dequantize_op(activation_type)
+ , _activation_type(activation_type)
, _is_layer_out(is_layer_out)
{
}
@@ -392,6 +395,38 @@ namespace ctranslate2 {
/*trans_b=*/true,
output,
bias);
+ } else if (_qzero && _qscale) {
+ switch (_quant_method) {
+ case models::QUANTIZATION_TYPE::AWQ_GEMM:
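+          // For larger inputs (dim(0) * dim(1) >= 1024), dequantize the AWQ weights and run a
+          // regular FP16 GEMM; otherwise call the fused AWQ GEMM kernel directly.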
+ if (input.dim(0) * input.dim(1) >= 1024) {
+ StorageView weight_dequant(input.dtype(), input.device());
+ ops::DequantizeAwq dequantize_awq_op;
+ dequantize_awq_op(*weight, *qscale, *_qzero, weight_dequant);
+ ops::Gemm gemm_op(/*alpha=*/1,
+ /*beta=*/0,
+ /*trans_a=*/false,
+ /*trans_b=*/false,
+ /*a_is_packed=*/false,
+ /*b_is_packed*/false,
+ _activation_type);
+ gemm_op(input, weight_dequant, output, nullptr, bias);
+ } else {
+ ops::GemmAwq gemm_awq_op(/*alpha=*/1, /*beta=*/0, /*trans_a=*/false, /*trans_b=*/false,
+ /*a_is_packed=*/false, /*b_is_packed=*/false, _activation_type);
+ gemm_awq_op(input, *weight, *qscale, *_qzero, output, bias);
+ }
+ break;
+ case models::QUANTIZATION_TYPE::AWQ_GEMV:
+ {
+ ops::GemvAwq gemv_awq_op(/*alpha=*/1, /*beta=*/0, /*trans_a=*/false, /*trans_b=*/false,
+ /*a_is_packed=*/false, /*b_is_packed=*/false, _activation_type);
+ gemv_awq_op(input, *weight, *qscale, *_qzero, output, bias);
+ break;
+ }
+ default:
+          throw std::invalid_argument("Dense forward: invalid quantization type, "
+                                      "only CT2 and AWQ quantization are supported");
+ }
} else {
_gemm_op(input, *weight, output, nullptr, bias);
}
diff --git a/src/models/model.cc b/src/models/model.cc
index dd7273d57..b8e1c2d8f 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -173,7 +173,7 @@ namespace ctranslate2 {
_device_index = index;
}
- void Model::set_compute_type(ComputeType type, Device device, int device_index) {
+ void Model::set_compute_type(ComputeType type, Device device, int device_index, bool update_weight) {
if (_device != Device::CPU)
throw std::runtime_error("set_compute_type expects the variables to be on CPU");
@@ -187,46 +187,47 @@ namespace ctranslate2 {
device,
device_index);
- DataType weight_dtype = DataType::FLOAT32;
- DataType float_dtype = DataType::FLOAT32;
- std::tie(weight_dtype, float_dtype) = compute_type_to_data_type(_effective_compute_type);
- if (_use_flash_attention && (float_dtype != DataType::FLOAT16 && float_dtype != DataType::BFLOAT16))
- throw std::runtime_error("FlashAttention only support fp16 and bf16 data type");
-
- const auto variable_index = _variable_index;
- for (auto& variable_pair : variable_index) {
- const auto& name = variable_pair.first;
- auto& variable = *variable_pair.second;
-
- // Convert "weight" variables to the expected compute type.
- // Other float variables (e.g. biases) may be converted to another float type.
- if (is_quantizable(name)) {
- auto variable_weight_dtype = weight_dtype;
- // For conv layer, we need to reshape to ensure dtype as its weights are 3D.
- auto is_conv = name.find("conv") != std::string::npos;
- auto kernel_size = -1;
- if (is_conv) {
- kernel_size = variable.dim(2);
- variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
- // For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype.
- if (device == Device::CUDA
- #ifdef CT2_WITH_DNNL
- || true
- #endif
- ) {
- variable_weight_dtype = float_dtype;
+ if (update_weight) {
+ DataType weight_dtype = DataType::FLOAT32;
+ DataType float_dtype = DataType::FLOAT32;
+ std::tie(weight_dtype, float_dtype) = compute_type_to_data_type(_effective_compute_type);
+ if (_use_flash_attention && (float_dtype != DataType::FLOAT16 && float_dtype != DataType::BFLOAT16))
+ throw std::runtime_error("FlashAttention only support fp16 and bf16 data type");
+
+ const auto variable_index = _variable_index;
+ for (auto& variable_pair : variable_index) {
+ const auto &name = variable_pair.first;
+ auto &variable = *variable_pair.second;
+
+ // Convert "weight" variables to the expected compute type.
+ // Other float variables (e.g. biases) may be converted to another float type.
+ if (is_quantizable(name)) {
+ auto variable_weight_dtype = weight_dtype;
+ // For conv layer, we need to reshape to ensure dtype as its weights are 3D.
+ auto is_conv = name.find("conv") != std::string::npos;
+ auto kernel_size = -1;
+ if (is_conv) {
+ kernel_size = variable.dim(2);
+ variable.reshape({variable.dim(0), variable.dim(1) * variable.dim(2)});
+ // For CUDA and DNNL backend, quantized convolution is not supported. Hence, convert to float_dtype.
+ if (device == Device::CUDA
+#ifdef CT2_WITH_DNNL
+ || true
+#endif
+ ) {
+ variable_weight_dtype = float_dtype;
+ }
}
- }
- ensure_dtype(name, variable, variable_weight_dtype);
- // Undo reshape for conv weights
- if (is_conv) {
- variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
- }
+ ensure_dtype(name, variable, variable_weight_dtype);
+ // Undo reshape for conv weights
+ if (is_conv) {
+ variable.reshape({variable.dim(0), variable.dim(1) / kernel_size, kernel_size});
+ }
+ } else if (is_convertible(variable, name)
+ && is_float_type(variable.dtype())
+ && variable.dtype() != float_dtype)
+ variable = variable.to(float_dtype);
}
- else if (is_convertible(variable, name)
- && is_float_type(variable.dtype())
- && variable.dtype() != float_dtype)
- variable = variable.to(float_dtype);
}
}
@@ -637,6 +638,10 @@ namespace ctranslate2 {
" the config.json could lead to error! Try using the latest version of converters");
}
+ QUANTIZATION_TYPE quantization_type = QUANTIZATION_TYPE::CT2;
+ if (model->config.contains("quantization_type"))
+ model->set_quant_method(model->config["quantization_type"]);
+
for (uint32_t i = 0; i < num_variables; ++i) {
auto name = consume(model_file);
const size_t rank = consume(model_file);
@@ -740,12 +745,24 @@ namespace ctranslate2 {
variable = std::move(outputs[current_index]);
}
}
-
model->register_variable(std::move(name), std::move(variable));
}
// Maybe quantize/dequantize/convert the variables to match the requested compute type.
- model->set_compute_type(compute_type, device, device_index);
+    // If the model was quantized with a method other than CT2, it relies on dedicated kernels,
+    // so the compute type associated with that quantization method has to be preserved.
+ switch (model->quant_method()) {
+ case QUANTIZATION_TYPE::CT2:
+ model->set_compute_type(compute_type, device, device_index);
+ break;
+ case QUANTIZATION_TYPE::AWQ_GEMM:
+ case QUANTIZATION_TYPE::AWQ_GEMV:
+ model->set_compute_type(ComputeType::FLOAT16, device, device_index, false);
+ break;
+ default:
+ throw std::invalid_argument("Quantization type is not supported");
+ break;
+ }
// Move variables to the target device.
model->set_device(device, device_index);
@@ -759,6 +776,7 @@ namespace ctranslate2 {
model->register_variable_alias(alias, variable_name);
// Also alias the quantization scale that could be associated to variable_name.
model->register_variable_alias(alias + "_scale", variable_name + "_scale");
+ model->register_variable_alias(alias + "_zero", variable_name + "_zero");
}
}
diff --git a/src/ops/awq/dequantize.cc b/src/ops/awq/dequantize.cc
new file mode 100644
index 000000000..eb909e087
--- /dev/null
+++ b/src/ops/awq/dequantize.cc
@@ -0,0 +1,24 @@
+#include <ctranslate2/ops/awq/dequantize_awq.h>
+
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ DequantizeAwq::DequantizeAwq() = default;
+
+ void DequantizeAwq::operator()(const StorageView& input,
+ const StorageView& scale,
+ const StorageView& zeros,
+ StorageView& output) const{
+ PROFILE("Dequantize Awq");
+
+    if (input.dtype() != DataType::INT32 || output.dtype() != DataType::FLOAT16)
+      throw std::invalid_argument("AWQ dequantization is only supported for int32 input and float16 output");
+    if (input.device() == Device::CPU)
+      throw std::invalid_argument("AWQ dequantization is only supported on GPU");
+
+    DEVICE_DISPATCH(input.device(), (dequantize<D, float16_t>(input, scale, zeros, output)));
+ }
+ }
+}
diff --git a/src/ops/awq/dequantize.cuh b/src/ops/awq/dequantize.cuh
new file mode 100644
index 000000000..799f68c3d
--- /dev/null
+++ b/src/ops/awq/dequantize.cuh
@@ -0,0 +1,79 @@
+#pragma once
+#include
+
+namespace ctranslate2 {
+ namespace ops {
+ __device__ __forceinline__ int make_divisible(int c, int divisor){
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ return (c + divisor - 1) / divisor;
+#endif
+ }
+
+ __inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
+ {
+ uint4 result;
+
+      uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+      uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
+
+ // First, we extract the i4s and construct an intermediate fp16 number.
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+ static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+ static constexpr uint32_t TOP_MASK = 0x00f000f0;
+ static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+ // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+ // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+ // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+ // elt_67 to fp16 without having to shift them to the bottom bits before hand.
+
+ // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+ // immediately before required.
+ const uint32_t top_i4s = i4s >> 8;
+ // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+ : "=r"(h[0])
+ : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+ // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+ : "=r"(h[1])
+ : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+ // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+ : "=r"(h[2])
+ : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+ // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+ asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+ : "=r"(h[3])
+ : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+ // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
+ // half2 ctor. In this case, I chose performance reliability over code readability.
+
+ // This is the half2 {1032, 1032} represented as an integer.
+ // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+ // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
+ static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+ // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+ // This is the half2 {-72, -72} represented as an integer.
+ // static constexpr uint32_t NEG_72 = 0xd480d480;
+ // Haotian: Let's use {-64, -64}.
+ static constexpr uint32_t NEG_64 = 0xd400d400;
+
+ // Finally, we construct the output numbers.
+ // Convert elt_01
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+ // Convert elt_23
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+ // Convert elt_45
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+ // Convert elt_67
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+ return result;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/ops/awq/dequantize_cpu.cc b/src/ops/awq/dequantize_cpu.cc
new file mode 100644
index 000000000..76cf4c3a9
--- /dev/null
+++ b/src/ops/awq/dequantize_cpu.cc
@@ -0,0 +1,23 @@
+#include <ctranslate2/ops/awq/dequantize_awq.h>
+
+namespace ctranslate2 {
+ namespace ops {
+    template <Device D, typename T>
+ void DequantizeAwq::dequantize(const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ StorageView&) const {
+      throw std::runtime_error("AWQ dequantization is not supported on CPU");
+ }
+
+#define DECLARE_IMPL(T) \
+ template void \
+  DequantizeAwq::dequantize<Device::CPU, T>(                            \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+ }
+}
diff --git a/src/ops/awq/dequantize_gpu.cu b/src/ops/awq/dequantize_gpu.cu
new file mode 100644
index 000000000..f63361035
--- /dev/null
+++ b/src/ops/awq/dequantize_gpu.cu
@@ -0,0 +1,115 @@
+#include <ctranslate2/ops/awq/dequantize_awq.h>
+#include "dequantize.cuh"
+#include "cuda/helpers.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ __global__ void __launch_bounds__(64) dequantize_weights(const int* __restrict__ B, // 4096x64 4096 rows 64 cols
+ const half * __restrict__ scaling_factors, // 32x512 32 rows 512 cols
+ const int* __restrict__ zeros, // 32x64 32 rows 64 cols
+ half * __restrict__ C, // 4096x512 4096 rows 512 cols
+ int G,
+ int in_c,
+ int out_c)
+ {
+ if (blockIdx.z > 0) {
+ B = B + blockIdx.z * in_c * out_c / 8;
+ scaling_factors = scaling_factors + blockIdx.z * in_c * out_c / G;
+ zeros = zeros + blockIdx.z * in_c * out_c / G / 8;
+ C = C + blockIdx.z * in_c * out_c;
+ }
+ static constexpr uint32_t ZERO = 0x0;
+ half B_shared[32 * (128 + 8)];
+
+ half* B_shared_ptr2 = B_shared;
+
+ int N = blockDim.x * gridDim.x; // 2
+ int col = (blockIdx.x * blockDim.x + threadIdx.x);
+ int row = blockIdx.y * blockDim.y + threadIdx.y;
+ int index1 = 8 * col + 8 * row * N; // + i (<8)
+ half* C_ptr2 = C + index1;
+
+ int index2 = col + row * N;
+ const int* B_ptr2 = B + index2;
+
+ int index3 = col + (int)(row / G) * N;
+ const int* zeros_ptr2 = zeros + index3;
+ int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8)
+ const half* scaling_factors_ptr2 = scaling_factors + index4;
+
+
+ uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
+ uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+ uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
+ int j=0;
+
+ uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
+ uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+
+ *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
+
+ for (int i=0; i<8; ++i) {
+ *(C_ptr2 + i) = B_shared[i];
+ }
+ }
+
+    template <Device D, typename T>
+ void DequantizeAwq::dequantize(const StorageView& input,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& output) const {
+ dim_t in_c = input.rank() == 2 ? input.dim(0) : input.dim(1);
+ dim_t qout_c = input.rank() == 2 ? input.dim(1) : input.dim(2);
+ int num_experts = input.rank() == 2 ? 1 : input.dim(0);
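+      // Each int32 element packs eight 4-bit values, so the dequantized width is 8x the packed width.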
+ int out_c = qout_c * 8;
+ int G = in_c / (input.rank() == 2 ? scale.dim(0) : scale.dim(1));
+
+ int x_thread = 0 /*thx*/;
+ int y_thread = 0 /*thy*/;
+
+ int x_blocks = 1;
+ int y_blocks = 1;
+ x_thread = qout_c;
+ y_thread = in_c;
+
+ x_thread = 8;
+ y_thread = 8;
+ x_blocks = (int)(qout_c / 8);
+ y_blocks = (int)(in_c / 8);
+ if (num_experts == 1) {
+ output.resize({in_c, out_c});
+ } else {
+ output.resize({num_experts, in_c, out_c});
+ }
+
+      auto output_data = reinterpret_cast<half*>(output.data<float16_t>());
+      const auto scale_data = reinterpret_cast<const half*>(scale.data<float16_t>());
+
+ dim3 num_blocks(x_blocks, y_blocks, num_experts);
+ dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096
+
+      dequantize_weights<<<num_blocks, threads_per_block, 0, cuda::get_cuda_stream()>>>(
+        input.data<int>(), scale_data, zero.data<int>(), output_data, G, in_c, out_c);
+ }
+
+#define DECLARE_IMPL(T) \
+ template void \
+ DequantizeAwq::dequantize( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+
+ }
+}
diff --git a/src/ops/awq/gemm.cc b/src/ops/awq/gemm.cc
new file mode 100644
index 000000000..bb3834213
--- /dev/null
+++ b/src/ops/awq/gemm.cc
@@ -0,0 +1,34 @@
+#include <ctranslate2/ops/awq/gemm.h>
+#include <ctranslate2/ops/sum.h>
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ void GemmAwq::operator()(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c,
+ const StorageView* bias) const {
+ PROFILE("Gemm Awq");
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+      throw std::runtime_error("AWQ GEMM is not supported on CUDA architectures older than 7.5");
+#else
+      if (a.dtype() != DataType::FLOAT16 || b.dtype() != DataType::INT32)
+        throw std::invalid_argument("AWQ GEMM is only supported for float16 input and int32 weight");
+      if (a.device() == Device::CPU)
+        throw std::invalid_argument("AWQ GEMM is only supported on GPU");
+
+      DEVICE_DISPATCH(a.device(), (compute<D, float16_t>(a, b, scale, zero, c)));
+
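+      // The kernel produces one partial result per split-k slice along the first dimension;
+      // sum over that dimension and drop it to obtain the final output.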
+ StorageView tmp(c.dtype(), c.device());
+ ops::Sum(0)(c, tmp);
+ tmp.squeeze(0);
+ c = std::move(tmp);
+
+ apply_bias_and_activation(c, bias, _activation_type);
+#endif
+ }
+ }
+}
diff --git a/src/ops/awq/gemm_cpu.cc b/src/ops/awq/gemm_cpu.cc
new file mode 100644
index 000000000..010385659
--- /dev/null
+++ b/src/ops/awq/gemm_cpu.cc
@@ -0,0 +1,25 @@
+#include <ctranslate2/ops/awq/gemm.h>
+
+namespace ctranslate2 {
+ namespace ops {
+    template <Device D, typename T>
+ void GemmAwq::compute(const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ StorageView&) const {
+      throw std::runtime_error("AWQ GEMM is not supported on CPU");
+ }
+
+#define DECLARE_IMPL(T) \
+ template void \
+  GemmAwq::compute<Device::CPU, T>(                                     \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+ }
+}
diff --git a/src/ops/awq/gemm_gpu.cu b/src/ops/awq/gemm_gpu.cu
new file mode 100644
index 000000000..544ce8559
--- /dev/null
+++ b/src/ops/awq/gemm_gpu.cu
@@ -0,0 +1,543 @@
+#include "cuda/utils.h"
+#include "dequantize.cuh"
+#include
+#include
+
+namespace ctranslate2 {
+ namespace ops {
+ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ scaling_factors, const int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+ {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ static constexpr uint32_t ZERO = 0x0;
+ float C_warp[32];
+ __shared__ half A_shared[16 * (32 + 8)];
+ __shared__ half B_shared[32 * (128 + 8)];
+
+ int j_factors1 = ((OC + 128 - 1) / 128);
+ int blockIdx_x = 0;
+ int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+ int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+ half A_shared_warp[8];
+ half B_shared_warp[32];
+ for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
+ for (int i = 0; i < 8; ++i) {
+ C_warp[(j_0_4_init * 8) + i] = 0.0;
+ }
+ }
+
+ static constexpr int row_stride_warp = 32 * 8 / 32;
+ static constexpr int row_stride = 2 * 32 * 8 / 128;
+ bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
+ // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+ bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
+ // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+ const half* A_ptr = A
+ + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+ const int* B_ptr = B
+ + ((int)threadIdx.y) * (OC / 8) * 2
+ + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?
+
+ half* A_shared_ptr = A_shared
+ + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+ half* B_shared_ptr = B_shared
+ + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ + (((int)threadIdx.x) % (128 / 8)) * 8;
+
+ const int* zeros_ptr = zeros
+ + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ + ((int)threadIdx.x) % (128 / 8);
+
+ const half* scaling_factors_ptr = scaling_factors
+ + (((int)blockIdx_y) % j_factors1) * (128)
+ + (((int)threadIdx.x) % (128 / 8)) * 8;
+
+ half* C_ptr = C
+ + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ + (((int)blockIdx_y) % j_factors1) * 128
+ + ((int)threadIdx.y) * 64
+ + (((int)threadIdx.x) % 4) * 2;
+
+ // preload s.f. and zeros
+ int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+ if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+ for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+ int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+ __syncthreads();
+ // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+ if (ld_A_flag)
+ {
+ *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+ }
+ else
+ {
+ *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+ }
+
+ // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+ uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+ uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+ uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+ /*
+ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+ }
+ */
+ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+ const int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+ for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+
+ // B: 32 x 136 (128+8) float16
+ // each warp: 32 x 4
+ // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+ // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+ // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+ uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+ uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+ //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+ // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+ // - zero and * scale
+ // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+ /*
+ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+ printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+ }
+ */
+
+ // write back
+ *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
+ }
+ __syncthreads();
+
+ for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+ );
+
+
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+ : "r"(addr)
+ );
+ }
+
+ for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
+ );
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+ : "r"(addr)
+ );
+ }
+ }
+ for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+#else
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+#endif
+ }
+ }
+ }
+
+// TODO: Shang: Hoist loop invariance.
+ for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
+ for (int local_id = 0; local_id < 8; ++local_id) {
+ int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+ if (row_offset < M)
+ {
+ *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+ }
+ }
+ }
+#endif
+ }
+
+
+ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ const scaling_factors, const int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+ {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ static constexpr uint32_t ZERO = 0x0;
+ float C_warp[32];
+ __shared__ half A_shared[16 * (32 + 8)];
+ __shared__ half B_shared[32 * (64 + 8)];
+
+ __shared__ half scaling_factors_shared[64];
+ __shared__ half zeros_shared[64];
+
+ int j_factors1 = ((OC + 64 - 1) / 64);
+
+ int blockIdx_x = 0;
+ int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+ int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+ half A_shared_warp[8];
+ half B_shared_warp[16];
+ for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+ for (int i = 0; i < 8; ++i) {
+ C_warp[(j_0_4_init * 8) + i] = 0.0;
+ }
+ }
+
+ static constexpr int row_stride_warp = 32 * 8 / 32;
+ static constexpr int row_stride = 2 * 32 * 8 / 64;
+ bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+ // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+ bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
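+ // ld_A_flag guards rows of A past M: when the 16-row tile is only partially filled, the missing rows are zero-filled below instead of being read from global memory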
+ // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+ const half* A_ptr = A
+ + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+ const int* B_ptr = B
+ + ((int)threadIdx.y) * (OC / 8) * 4
+ + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ + (((int)threadIdx.x) % (64 / 8)) * 1;
+// Why * 1 in the above line?
+
+ half* A_shared_ptr = A_shared
+ + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+ half* B_shared_ptr = B_shared
+ + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+ const int* zeros_ptr = zeros
+ + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ + ((int)threadIdx.x) % (64 / 8);
+
+ const half* scaling_factors_ptr = scaling_factors
+ + (((int)blockIdx_y) % j_factors1) * (64)
+ + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+ half* C_ptr = C
+ + static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ + (((int)blockIdx_y) % j_factors1) * 64
+ + ((int)threadIdx.y) * 32
+ + (((int)threadIdx.x) % 4) * 2;
+
+ // preload s.f. and zeros
+ int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+ if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
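+ // split-K: block blockIdx_z handles every split_k_iters-th 32-channel slice of the K dimension; k_bound is the number of slices assigned to this block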
+ for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+ int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+ __syncthreads();
+ // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+ if (ld_A_flag)
+ {
+ *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+ }
+ else
+ {
+ *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+ }
+
+ // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+ uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+ uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+ uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
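+ // the 8 packed 4-bit zero-points for this group are expanded to fp16 pairs; the 8 fp16 scales are loaded directly as a uint4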
+ /*
+ if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+ }
+ */
+ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+ const int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+ for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+
+ // B: 32 x 136 (128+8) float16
+ // each warp: 32 x 4
+ // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+ // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+ // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+ uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+ uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+ //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+ // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+ // - zero and * scale
+ // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+ asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
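+ // each sub.f16x2 / fma.rn.f16x2 pair applies (q - zero) * scale to two packed fp16 weights at once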
+ /*
+ if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+ printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+ }
+ */
+
+ // write back
+ *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+ }
+ __syncthreads();
+
+ for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
+ {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+ );
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+ : "r"(addr)
+ );
+ }
+
+
+ for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
+ {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+ );
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+ : "r"(addr)
+ );
+ }
+ }
+
+ for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
+ {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+#else
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+ }
+#endif
+ }
+ }
+ }
+
+// TODO: Shang: Hoist loop invariance.
+ for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+ for (int local_id = 0; local_id < 8; ++local_id) {
+ int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+ if (row_offset < M)
+ {
+ *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+ }
+ }
+ }
+#endif
+ }
+
+ template <Device D, typename T>
+ void GemmAwq::compute(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ dim_t num_in_channels = a.dim(-1);
+ dim_t num_in_feats = a.size() / num_in_channels;
+ dim_t split_k_iters = 8;
+
+ if (a.rank() == 2)
+ c.resize({split_k_iters, num_in_feats, b.dim(1) * 8});
+ else if (a.rank() == 3)
+ c.resize({split_k_iters, a.dim(0), a.dim(1), b.dim(1) * 8});
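+ // c carries a leading split_k_iters dimension: each split-K slice writes its own M x OC partial result (offset by blockIdx_z in the kernel)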
+
+ dim_t num_out_feats = num_in_feats;
+ dim_t num_out_channels = c.dim(-1);
+
+ const auto a_data = reinterpret_cast<const half*>(a.data<T>());
+ const auto b_data = reinterpret_cast<const int*>(b.data<int32_t>());
+ auto output_data = reinterpret_cast<half*>(c.data<T>());
+ const auto scale_data = reinterpret_cast<const half*>(scale.data<T>());
+ const auto zero_data = reinterpret_cast<const int*>(zero.data<int32_t>());
+ dim_t group_size = num_in_channels / scale.dim(0);
+
+ if (num_out_channels % 64 != 0)
+ throw std::invalid_argument("OC is not a multiple of cta_N = 64");
+ if (num_out_channels % 8 != 0)
+ throw std::invalid_argument("OC is not a multiple of pack_num = 8");
+ if (group_size % 32 != 0)
+ throw std::invalid_argument("Group size should be a multiple of 32");
+ if (num_out_channels % group_size != 0)
+ throw std::invalid_argument("OC is not a multiple of group_size");
+
+ if (num_out_channels % 128 == 0)
+ {
+ dim_t j_factors1 = num_out_channels / 128 / 1;
+ dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+ // threadIdx.x: 32
+ // threadIdx.y: i_factors[2] * j_factors[2]
+ dim3 threads_per_block(32, 2);
+ gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, cuda::get_cuda_stream()>>>(
+ group_size, split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data);
+ }
+ else if (num_out_channels % 64 == 0)
+ {
+ int j_factors1 = num_out_channels / 64 / 1;
+ dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+
+ // threadIdx.x: 32
+ // threadIdx.y: i_factors[2] * j_factors[2]
+ dim3 threads_per_block(32, 2);
+ gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, cuda::get_cuda_stream()>>>(
+ group_size, split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data);
+ }
+#endif
+ }
+
+
+#define DECLARE_IMPL(T) \
+ template void \
+ GemmAwq::compute<Device::CUDA, T>( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+ }
+}
\ No newline at end of file
diff --git a/src/ops/awq/gemv.cc b/src/ops/awq/gemv.cc
new file mode 100644
index 000000000..0245ac1af
--- /dev/null
+++ b/src/ops/awq/gemv.cc
@@ -0,0 +1,39 @@
+#include <ctranslate2/ops/awq/gemv.h>
+#include <ctranslate2/ops/gemm.h>
+#include <ctranslate2/ops/sum.h>
+
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ void GemvAwq::operator()(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c,
+ const StorageView* bias) const {
+ PROFILE("Gemv Awq");
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ throw std::runtime_error("AWQ GEMV is not supported for CUDA compute capability < 7.5");
+#else
+ if (a.dtype() != DataType::FLOAT16 || b.dtype() != DataType::INT32)
+ throw std::invalid_argument("AWQ gemv is only supported for float16 input and int32 weight");
+ if (a.device() == Device::CPU)
+ throw std::invalid_argument("AWQ gemv is only supported on GPU");
+
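+ // batches with more than 8 rows take the split-K GEMV2 path: partial results are reduced over axis 0 with ops::Sum, then the split dimension is squeezed away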
+ if (a.dim(0) > 8) {
+ DEVICE_DISPATCH(a.device(), (compute_gemv2<D, float16_t>(a, b, scale, zero, c)));
+ StorageView tmp(c.dtype(), c.device());
+ ops::Sum(0)(c, tmp);
+ tmp.squeeze(0);
+ c = std::move(tmp);
+ }
+ else
+ DEVICE_DISPATCH(a.device(), (compute_gemv<D, float16_t>(a, b, scale, zero, c)));
+
+ apply_bias_and_activation(c, bias, _activation_type);
+#endif
+ }
+ }
+}
diff --git a/src/ops/awq/gemv_cpu.cc b/src/ops/awq/gemv_cpu.cc
new file mode 100644
index 000000000..77603da11
--- /dev/null
+++ b/src/ops/awq/gemv_cpu.cc
@@ -0,0 +1,40 @@
+#include <ctranslate2/ops/awq/gemv.h>
+
+namespace ctranslate2 {
+ namespace ops {
+ template <Device D, typename T>
+ void GemvAwq::compute_gemv(const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ StorageView&) const {
+ throw std::runtime_error("AWQ GEMV is not supported on CPU");
+ }
+ template <Device D, typename T>
+ void GemvAwq::compute_gemv2(const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ const StorageView&,
+ StorageView&) const {
+ throw std::runtime_error("AWQ GEMV2 is not supported on CPU");
+ }
+
+#define DECLARE_IMPL(T) \
+ template void \
+ GemvAwq::compute_gemv2<Device::CPU, T>( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const; \
+ template void \
+ GemvAwq::compute_gemv<Device::CPU, T>( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+ }
+}
diff --git a/src/ops/awq/gemv_gpu.cu b/src/ops/awq/gemv_gpu.cu
new file mode 100644
index 000000000..a24b92773
--- /dev/null
+++ b/src/ops/awq/gemv_gpu.cu
@@ -0,0 +1,548 @@
+#include "cuda/utils.h"
+#include "dequantize.cuh"
+#include <ctranslate2/ops/awq/gemv.h>
+#include
+#define PACK_FACTOR 8
+#define WARP_SIZE 32
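+// each 32-bit weight word packs PACK_FACTOR = 8 int4 values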
+
+namespace ctranslate2 {
+ namespace ops {
+ template <int G>
+ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int split_k_iters, const half* __restrict__ A, const int* __restrict__ B, const half* __restrict__ scaling_factors, const int* zeros, int M, int IC, int OC, half* __restrict__ C)
+ {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ static constexpr uint32_t ZERO = 0x0;
+ float C_warp[64];
+ __shared__ half A_shared[128 * (32 + 8)];
+ __shared__ half B_shared[64 * (32 + 8)];
+
+ // __shared__ half scaling_factors_shared[64];
+ // __shared__ half zeros_shared[64];
+
+ int j_factors1 = ((OC + 64 - 1) / 64);
+
+ //int blockIdx_x = 0;
+ int blockIdx_y = blockIdx.x % ((M + 128 - 1) / 128 * j_factors1);
+ int blockIdx_z = blockIdx.x / ((M + 128 - 1) / 128 * j_factors1);
+
+ half A_shared_warp[32];
+ half B_shared_warp[16];
+ for (int i_0_3_init = 0; i_0_3_init < 4; ++i_0_3_init) {
+ for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+ for (int i = 0; i < 8; ++i) {
+ C_warp[((i_0_3_init * 16) + (j_0_4_init * 8)) + i] = 0.0;
+ }
+ }
+ }
+
+ static constexpr int row_stride_warp = 32 * 8 / 32;
+ static constexpr int row_stride_A = 4 * 32 * 8 / 32;
+ static constexpr int row_stride = 4 * 32 * 8 / 32;
+ const int make_divisible_multiplier = 128 / G;
+ const int zeros_w = make_divisible(make_divisible(IC / G, 8), make_divisible_multiplier) * make_divisible_multiplier;
+ const int sf_w = zeros_w * 8;
+
+ int ld_A_row = (blockIdx_y / j_factors1 * 128 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32); // threadIdx.y is warp_id
+ // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+ const half* A_ptr = A
+ + (((int)blockIdx_y) / j_factors1 * 128 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+ const int* B_ptr = B
+ + ((int)threadIdx.y) * (IC / 8) * 8
+ + (((int)threadIdx.x) / (32 / 8)) * (IC / 8)
+ + (((int)blockIdx_y) % j_factors1) * 64 * (IC / 8)
+ + (((int)threadIdx.x) % (32 / 8)) * 1;
+
+// Why * 1 in the above line?
+
+ half* A_shared_ptr = A_shared
+ + ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+ half* B_shared_ptr = B_shared
+ + ((int)threadIdx.y) * (row_stride / 4) * (32 + 8)
+ + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ + (((int)threadIdx.x) % (32 / 8)) * 8;
+
+
+ const int* zeros_ptr = zeros
+ + ((int)threadIdx.y) * zeros_w * 8
+ + (((int)threadIdx.x) / (32 / 8)) * zeros_w
+ + (((int)blockIdx_y) % j_factors1) * 64 * zeros_w
+ // this term is zero
+ + (((int)threadIdx.x) % (32 / 8)) / G ;
+
+ const half* scaling_factors_ptr = scaling_factors
+ + ((int)threadIdx.y) * sf_w * 8
+ + (((int)threadIdx.x) / (32 / 8)) * sf_w
+ + (((int)blockIdx_y) % j_factors1) * (64) * sf_w
+ // this term is zero
+ + (((int)threadIdx.x) % (32 / 8)) * 8 / G;
+
+
+ // Haotian: TBD, check, May 29 11:46 AM PST
+ half* C_ptr = C
+ + static_cast<long long>(blockIdx_z) * M * OC // blockIdx_z -> split_k dim
+ + (((int)blockIdx_y) % j_factors1) * 64
+ + (((int)threadIdx.y) / 2) * 32
+ + (((int)threadIdx.x) % 4) * 2;
+
+ // preload s.f. and zeros
+ int k_bound = make_divisible(IC / 32, split_k_iters); // (IC / 32 + split_k_iters - 1) / split_k_iters;
+ if ((k_bound - 1) * 32 + blockIdx_z >= IC) k_bound -= 1;
+
+ // TODO (Haotian): load scales and zero points to smem
+
+ for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+ int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+ __syncthreads();
+ // TODO: Haotian: Here we assume M % cta_M = 0.
+ for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0)
+ {
+ if (ld_A_row + ax0_ax1_fused_0 * row_stride_A < M)
+ {
+ *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = *(uint4*)(A_ptr + (ax0_ax1_fused_0 * row_stride_A * IC) + (k_0_0 * 32));
+ }
+ else
+ {
+ *(uint4*)(A_shared_ptr + ax0_ax1_fused_0 * row_stride_A * 40) = make_uint4(0, 0, 0, 0);
+ }
+ }
+
+
+ const int* zeros_ptr_local = zeros_ptr + k_0_0 * 32 / G / 8;
+ const half* scaling_factors_ptr_local = scaling_factors_ptr + k_0_0 * 32 / G;
+
+ // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+ const int* B_ptr_local = B_ptr + k_0_0 * (32 / 8);
+
+ for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+
+ // B: 32 x 136 (128+8) float16
+ // each warp: 32 x 4
+ // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+ // row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
+ int B_loaded_current = *(B_ptr_local + ax0_ax1_fused_0 * row_stride * (IC / 8));
+ int zeros_loaded = *(zeros_ptr_local + ax0_ax1_fused_0 * row_stride * zeros_w);
+ zeros_loaded >>= ((k_0_0 * 32 / G) % 8) * 4;
+ float current_zeros = (float)(zeros_loaded & 0xF);
+ half scaling_factors_loaded = *(scaling_factors_ptr_local + ax0_ax1_fused_0 * row_stride * sf_w);
+ half B_loaded_fp16[8];
+#pragma unroll
+ for (int ic_1 = 0; ic_1 < 8; ic_1++){
+ float current_single_weight_fp = (float)(B_loaded_current & 0xF);
+ half dequantized_weight = __float2half(__half2float(scaling_factors_loaded) * (current_single_weight_fp - current_zeros));
+ B_loaded_current = B_loaded_current >> 4;
+ B_loaded_fp16[ic_1] = dequantized_weight;
+ }
+ // write back
+ *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (32 + 8)) = *reinterpret_cast<uint4*>(B_loaded_fp16);
+ }
+ __syncthreads();
+ for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+ for (int ax0_0 = 0; ax0_0 < 4; ++ax0_0) {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(A_shared[((((((int)threadIdx.y) & 1) * 2560) + (ax0_0 * 640)) + (k_0_1 * 16))])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+ );
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(A_shared_warp + (ax0_0 * 8)))[3])
+ : "r"(addr)
+ );
+ }
+ }
+
+ for (int ax0_0_1 = 0; ax0_0_1 < 2; ++ax0_0_1) {
+ {
+ unsigned int addr;
+ asm volatile(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)((&(B_shared[((((((int)threadIdx.y) >> 1) * 1280) + (ax0_0_1 * 640)) + (k_0_1 * 16))])) + ((((((int)threadIdx.x) >> 4) * 320) + ((((int)threadIdx.x) & 7) * 40)) + (((((int)threadIdx.x) & 15) >> 3) * 8))))
+ );
+ asm volatile(
+ "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+ "{%0, %1, %2, %3}, [%4];\n"
+ : "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax0_0_1 * 8)))[3])
+ : "r"(addr)
+ );
+ }
+ }
+
+ for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) {
+ for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+ : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8 + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8) + 4)))[3]));
+ }
+#else
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+ }
+
+ {
+ asm volatile(
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+ : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+ : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]));
+ }
+#endif
+ }
+ }
+ }
+ }
+
+// Haotian: Here (May 29 11:46AM PST)
+// TODO: Shang: Hoist loop invariance.
+ for (int ax0_0_2 = 0; ax0_0_2 < 4; ++ax0_0_2) {
+ for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) {
+ for (int local_id = 0; local_id < 8; ++local_id) {
+ int row_offset = (((int)blockIdx_y) / j_factors1) * 128 + (threadIdx.y % 2) * 64 + ax0_0_2 * 16 + (local_id % 4) / 2 * 8 + ((int)threadIdx.x) / 4;
+ if (row_offset < M)
+ {
+ *(C_ptr + ax1_0 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax0_0_2 * 16) + (ax1_0 * 8) + local_id]);
+ }
+ }
+ }
+ }
+#endif
+ }
+
+ // Reduce sum within the warp using the tree reduction algorithm.
+ __device__ __forceinline__ float warp_reduce_sum(float sum) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+#pragma unroll
+ for(int i = 4; i >= 0; i--){
+ sum += __shfl_down_sync(0xffffffff, sum, 1<<i);
+ }
+ return sum;
+#endif
+ }
+
+/*
+Computes GEMV (group_size = 64).
+
+Args:
+ inputs: vector of shape [batch_size, IC];
+ weight: matrix of shape [OC, IC / 8];
+ output: vector of shape [OC];
+ zeros: matrix of shape [OC, IC / group_size / 8];
+ scaling_factors: matrix of shape [OC, IC / group_size];
+
+Notes:
+ One cannot infer group_size from the shape of scaling factors.
+ the second dimension is rounded up to a multiple of PACK_FACTOR.
+*/
+ __global__ void gemv_kernel_g64(
+ const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
+ const int IC, const int OC){
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ const int group_size = 64;
+ float psum = 0;
+ const int batch_idx = blockIdx.z;
+ const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
+ const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
+ half* outputs = _outputs + batch_idx * OC;
+ const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
+ const int weight_w = IC / PACK_FACTOR;
+ const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
+ // consistent with input shape
+ const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
+ // tile size: 4 OC x 1024 IC per iter
+ for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
+ // g64: 16 groups per warp iteration -> 16 packed 4-bit zero-points, read as one 64-bit word
+ uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
+ uint32_t packed_weights[4];
+ // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
+ *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
+ // load scaling factors
+ // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
+ float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
+ float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
+ int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
+ const float4* inputs_ptr = inputs + inputs_ptr_delta;
+ // multiply 32 weights with 32 inputs
+#pragma unroll
+ for (int ic_0 = 0; ic_0 < 4; ic_0++){
+ // iterate over different uint32_t packed_weights in this loop
+ uint32_t current_packed_weight = packed_weights[ic_0];
+ half packed_inputs[PACK_FACTOR];
+ // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
+ if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
+ *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
+#pragma unroll
+ for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
+ // iterate over 8 numbers packed within each uint32_t number
+ float current_single_weight_fp = (float)(current_packed_weight & 0xF);
+ float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
+ //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
+ psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
+ current_packed_weight = current_packed_weight >> 4;
+ }
+ }
+ }
+ }
+ psum = warp_reduce_sum(psum);
+ if (threadIdx.x == 0) {
+ outputs[oc_idx] = __float2half(psum);
+ }
+#endif
+ }
+
+
+/*
+Computes GEMV (group_size = 128).
+
+Args:
+ inputs: vector of shape [batch_size, IC];
+ weight: matrix of shape [OC, IC / 8];
+ output: vector of shape [OC];
+ zeros: matrix of shape [OC, IC / group_size / 8];
+ scaling_factors: matrix of shape [OC, IC / group_size];
+
+Notes:
+ One cannot infer group_size from the shape of scaling factors.
+ the second dimension is rounded up to a multiple of PACK_FACTOR.
+*/
+ __global__ void gemv_kernel_g128(
+ const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
+ const int IC, const int OC){
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ const int group_size = 128;
+ float psum = 0;
+ const int batch_idx = blockIdx.z;
+ const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
+ const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
+ half* outputs = _outputs + batch_idx * OC;
+ const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
+ const int weight_w = IC / PACK_FACTOR;
+ // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
+ const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
+ // consistent with input shape
+ const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
+ //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
+ // tile size: 4 OC x 1024 IC per iter
+ for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
+ // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
+ uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
+ uint32_t packed_weights[4];
+ // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
+ *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
+ // load scaling factors
+ // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
+ float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
+ float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
+ int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
+ const float4* inputs_ptr = inputs + inputs_ptr_delta;
+ // multiply 32 weights with 32 inputs
+#pragma unroll
+ for (int ic_0 = 0; ic_0 < 4; ic_0++){
+ // iterate over different uint32_t packed_weights in this loop
+ uint32_t current_packed_weight = packed_weights[ic_0];
+ half packed_inputs[PACK_FACTOR];
+ // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
+ if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
+ *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
+#pragma unroll
+ for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
+ // iterate over 8 numbers packed within each uint32_t number
+ float current_single_weight_fp = (float)(current_packed_weight & 0xF);
+ float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
+ //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
+ psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
+ current_packed_weight = current_packed_weight >> 4;
+ }
+ }
+ }
+ }
+ psum = warp_reduce_sum(psum);
+ if (threadIdx.x == 0) {
+ outputs[oc_idx] = __float2half(psum);
+ }
+#endif
+ }
+
+ template <Device D, typename T>
+ void GemvAwq::compute_gemv(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ dim_t num_in_channels = a.dim(-1);
+ dim_t num_in_feats = a.size() / num_in_channels;
+ if (a.rank() == 2)
+ c.resize({num_in_feats, b.dim(0)});
+ else if (a.rank() == 3)
+ c.resize({a.dim(0), a.dim(1), b.dim(0)});
+
+ const auto a_data = reinterpret_cast<const float4*>(a.data<T>());
+ const auto b_data = reinterpret_cast<const uint32_t*>(b.data<int32_t>());
+ auto output_data = reinterpret_cast<half*>(c.data<T>());
+ const auto scale_data = reinterpret_cast<const half*>(scale.data<T>());
+ const auto zero_data = reinterpret_cast<const uint32_t*>(zero.data<int32_t>());
+ dim_t group_size = num_in_channels / scale.dim(-1);
+
+ dim_t num_out_feats = num_in_feats;
+ dim_t num_out_channels = c.dim(-1);
+ dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
+ dim3 num_threads(32, 4);
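+ // grid: blockIdx.y selects a group of 4 output channels, blockIdx.z selects the output row; the 32 threads along x cooperatively reduce one output channel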
+ if (group_size == 64)
+ {
+ gemv_kernel_g64<<<num_blocks, num_threads, 0, cuda::get_cuda_stream()>>>(
+ // pointers
+ a_data, b_data, zero_data, scale_data, output_data,
+ // constants
+ num_in_channels, num_out_channels
+ );
+ }
+ else if (group_size == 128)
+ {
+ gemv_kernel_g128<<<num_blocks, num_threads, 0, cuda::get_cuda_stream()>>>(
+ // pointers
+ a_data, b_data, zero_data, scale_data, output_data,
+ // constants
+ num_in_channels, num_out_channels
+ );
+ }
+#endif
+ }
+
+ template <Device D, typename T>
+ void GemvAwq::compute_gemv2(const StorageView& a,
+ const StorageView& b,
+ const StorageView& scale,
+ const StorageView& zero,
+ StorageView& c) const {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+ assert(false);
+#else
+ dim_t num_in_channels = a.dim(-1);
+ dim_t num_in_feats = a.size() / num_in_channels;
+ dim_t split_k_iters = 8;
+
+ if (a.rank() == 2)
+ c.resize({split_k_iters, num_in_feats, b.dim(0)});
+ else if (a.rank() == 3)
+ c.resize({split_k_iters, a.dim(0), a.dim(1), b.dim(0)});
+
+ dim_t num_out_feats = num_in_feats;
+ dim_t num_out_channels = c.dim(-1);
+
+ const auto a_data = reinterpret_cast<const half*>(a.data<T>());
+ const auto b_data = reinterpret_cast<const int*>(b.data<int32_t>());
+ auto output_data = reinterpret_cast<half*>(c.data<T>());
+ const auto scale_data = reinterpret_cast<const half*>(scale.data<T>());
+ const auto zero_data = reinterpret_cast<const int*>(zero.data<int32_t>());
+ dim_t group_size = num_in_channels / scale.dim(-1);
+
+ if (num_out_channels % 64 != 0)
+ throw std::invalid_argument("OC is not a multiple of cta_N = 64");
+ if (num_out_channels % 8 != 0)
+ throw std::invalid_argument("OC is not a multiple of pack_num = 8");
+
+ int j_factors1 = num_out_channels / 64 / 1;
+ dim3 num_blocks((num_out_feats + 128 - 1) / 128 * j_factors1 * split_k_iters);
+
+ // threadIdx.x: 32
+ // threadIdx.y: i_factors[2] * j_factors[2]
+ dim3 threads_per_block(32, 4);
+ if (group_size == 128)
+ {
+ gemmv2_forward_4bit_cuda_m128n64k32<128><<<num_blocks, threads_per_block, 0, cuda::get_cuda_stream()>>>(
+ split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data);
+ }
+ else if (group_size == 64)
+ {
+ gemmv2_forward_4bit_cuda_m128n64k32<64><<<num_blocks, threads_per_block, 0, cuda::get_cuda_stream()>>>(
+ split_k_iters, a_data, b_data, scale_data, zero_data, num_in_feats, num_in_channels, num_out_channels, output_data);
+ }
+ else
+ {
+ throw std::invalid_argument("Only group sizes 64 and 128 are currently supported.");
+ }
+#endif
+ }
+
+
+#define DECLARE_IMPL(T) \
+ template void \
+ GemvAwq::compute_gemv2<Device::CUDA, T>( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const; \
+ template void \
+ GemvAwq::compute_gemv<Device::CUDA, T>( \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ const StorageView&, \
+ StorageView&) const;
+
+ DECLARE_IMPL(float16_t)
+ }
+}
\ No newline at end of file
diff --git a/src/ops/mean.cc b/src/ops/mean.cc
index 042f8c169..afac4cd82 100644
--- a/src/ops/mean.cc
+++ b/src/ops/mean.cc
@@ -39,7 +39,7 @@ namespace ctranslate2 {
inner_size *= input.dim(i);
DEVICE_AND_FLOAT_DISPATCH("Mean", input.device(), input.dtype(),
- (compute<D, T>(input, outer_size, axis_size, inner_size, output)));
+ (compute<D, T>(input, outer_size, axis_size, inner_size, false, output)));
}
}
diff --git a/src/ops/mean_cpu.cc b/src/ops/mean_cpu.cc
index be8fd9d90..a4edb12c8 100644
--- a/src/ops/mean_cpu.cc
+++ b/src/ops/mean_cpu.cc
@@ -11,6 +11,7 @@ namespace ctranslate2 {
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
+ const bool get_sum,
StorageView& output) const {
const auto* src = input.data<T>();
auto* dst = output.data<T>();
@@ -22,7 +23,9 @@ namespace ctranslate2 {
for (dim_t k = 0; k < axis_size; ++k) {
sum += src[i * axis_size * inner_size + k * inner_size + j];
}
- dst[i * inner_size + j] = sum / float(axis_size);
+ dst[i * inner_size + j] = sum;
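+ // when get_sum is true (used by ops::Sum) the raw sum is kept and the averaging below is skipped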
+ if (!get_sum)
+ dst[i * inner_size + j] /= float(axis_size);
}
}
});
@@ -34,6 +37,7 @@ namespace ctranslate2 {
const dim_t outer_size, \
const dim_t axis_size, \
const dim_t inner_size, \
+ const bool get_sum, \
StorageView& output) const;
DECLARE_IMPL(float)
diff --git a/src/ops/mean_gpu.cu b/src/ops/mean_gpu.cu
index 5125924c9..a57a679a6 100644
--- a/src/ops/mean_gpu.cu
+++ b/src/ops/mean_gpu.cu
@@ -15,6 +15,7 @@ namespace ctranslate2 {
const cuda::index_t outer_size,
const cuda::index_t axis_size,
const cuda::index_t inner_size,
+ const bool get_sum,
T* output) {
typedef cub::BlockReduce BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -30,7 +31,9 @@ namespace ctranslate2 {
AccumT sum = BlockReduce(temp_storage).Sum(thread_sum);
if (threadIdx.x == 0) {
- output[blockIdx.x] = sum / AccumT(axis_size);
+ output[blockIdx.x] = sum;
+ if (!get_sum)
+ output[blockIdx.x] /= AccumT(axis_size);
}
}
@@ -39,6 +42,7 @@ namespace ctranslate2 {
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
+ const bool get_sum,
StorageView& output) const {
const dim_t blocks = std::min(outer_size * inner_size, cuda::max_blocks);
mean_kernel, float><<>>(
@@ -46,6 +50,7 @@ namespace ctranslate2 {
outer_size,
axis_size,
inner_size,
+ get_sum,
cuda::device_cast(output.data<T>()));
}
@@ -55,6 +60,7 @@ namespace ctranslate2 {
const dim_t outer_size, \
const dim_t axis_size, \
const dim_t inner_size, \
+ const bool get_sum, \
StorageView& output) const;
DECLARE_IMPL(float)
diff --git a/src/ops/sum.cc b/src/ops/sum.cc
new file mode 100644
index 000000000..bbbc5a8bc
--- /dev/null
+++ b/src/ops/sum.cc
@@ -0,0 +1,44 @@
+#include "ctranslate2/ops/sum.h"
+
+#include "dispatch.h"
+namespace ctranslate2 {
+ namespace ops {
+
+ Sum::Sum(const dim_t axis)
+ : Mean(axis)
+ {
+ }
+
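+ // Sum reuses Mean's reduction kernels with get_sum = true, so the final division by axis_size is skipped.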
+ void Sum::operator()(const StorageView& input, StorageView& output) const {
+ PROFILE("Sum");
+
+ const dim_t axis = _axis < 0 ? input.rank() + _axis : _axis;
+ if (axis >= input.rank())
+ throw std::out_of_range("Cannot compute sum of axis " + std::to_string(axis)
+ + " for input with rank " + std::to_string(input.rank()));
+
+ const dim_t axis_size = input.dim(axis);
+ if (axis_size == 1) {
+ output = input;
+ return;
+ }
+
+ {
+ Shape output_shape(input.shape());
+ output_shape[axis] = 1;
+ output.resize(std::move(output_shape));
+ }
+
+ dim_t inner_size = 1;
+ dim_t outer_size = 1;
+ for (dim_t i = 0; i < axis; ++i)
+ outer_size *= input.dim(i);
+ for (dim_t i = axis + 1; i < input.rank(); ++i)
+ inner_size *= input.dim(i);
+
+ DEVICE_AND_FLOAT_DISPATCH("Sum", input.device(), input.dtype(),
+ (compute<D, T>(input, outer_size, axis_size, inner_size, true, output)));
+ }
+
+ }
+}