AWQ 4-bit quantization support (#2508)
* add awq linear from AutoAWQ and/or llm-awq
* add generic converter for llama-like models from HF with or without awq quantization
vince62s authored Nov 23, 2023
1 parent 3d4c8de commit b2629b6
Showing 6 changed files with 582 additions and 16 deletions.
17 changes: 16 additions & 1 deletion README.md
@@ -63,7 +63,7 @@ Install `OpenNMT-py` from `pip`:
pip install OpenNMT-py
```

-or from the sources:
+or from the source:
```bash
git clone https://github.com/OpenNMT/OpenNMT-py.git
cd OpenNMT-py
@@ -107,6 +107,21 @@ When using `max_relative_positions > 0` or Alibi `max_relative_positions=-2` Ope

Flash attention and `F.scaled_dot_product_attention` are a bit faster and save some GPU memory.


AWQ:

If you want to run inference or quantize an AWQ model, you will need llm-awq and/or AutoAWQ.

For [llm-awq](https://github.com/mit-han-lab/llm-awq):
git clone https://github.com/mit-han-lab/llm-awq
cd llm-awq
pip install -e .
cd ..

For [AutoAWQ](https://github.com/casper-hansen/AutoAWQ):
pip install autoawq
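
As a quick sanity check after installing either backend, the sketch below (purely illustrative, not part of the commit) probes which AWQ implementation is visible to Python; the import paths are the same ones the new `onmt/modules/awq_linear.py` relies on:

```python
# Probe which AWQ backend is importable; these import paths mirror the ones
# used by replace_awq_linear() in this commit.
def available_awq_backends():
    backends = []
    try:
        from awq.quantize.qmodule import WQLinear  # noqa: F401  (llm-awq)
        backends.append("llm_awq")
    except ImportError:
        pass
    try:
        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV  # noqa: F401  (AutoAWQ)
        backends.extend(["aawq_gemm", "aawq_gemv"])
    except ImportError:
        pass
    return backends


print(available_awq_backends())
```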


## Documentation & FAQs

[Full HTML Documentation](https://opennmt.net/OpenNMT-py/quickstart.html)
25 changes: 13 additions & 12 deletions onmt/decoders/transformer.py
@@ -899,15 +899,16 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs):

     def _init_cache(self, tgt=None):
         for layer in self.transformer_layers:
-            if isinstance(layer.self_attn, AverageAttention):
-                raise NotImplementedError
-            else:
-                layer.self_attn.layer_cache = (
-                    True,
-                    {
-                        "keys": torch.tensor([], device=tgt.device),
-                        "values": torch.tensor([], device=tgt.device),
-                    },
-                )
-            if hasattr(layer.self_attn, "rope"):
-                layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
+            if hasattr(layer, "self_attn"):
+                if isinstance(layer.self_attn, AverageAttention):
+                    raise NotImplementedError
+                else:
+                    layer.self_attn.layer_cache = (
+                        True,
+                        {
+                            "keys": torch.tensor([], device=tgt.device),
+                            "values": torch.tensor([], device=tgt.device),
+                        },
+                    )
+                if hasattr(layer.self_attn, "rope"):
+                    layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
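
The net effect of the new `hasattr` guard, shown in isolation with toy classes (not the real OpenNMT-py layer types): cache initialization now simply skips any decoder layer that does not expose a `self_attn` attribute instead of raising an `AttributeError` on it.

```python
# Toy illustration of the guard added above; the classes are hypothetical.
class LayerWithAttention:
    def __init__(self):
        self.self_attn = object()


class LayerWithoutAttention:
    pass


layers = [LayerWithAttention(), LayerWithoutAttention(), LayerWithAttention()]
initialised = [layer for layer in layers if hasattr(layer, "self_attn")]
print(len(initialised))  # 2 -- the attention-less layer is skipped, not an error
```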
24 changes: 22 additions & 2 deletions onmt/model_builder.py
@@ -95,8 +95,13 @@ def load_test_model(opt, device_id=0, model_path=None):

     model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])

-    model_opt.quant_layers = opt.quant_layers
-    model_opt.quant_type = opt.quant_type
+    if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
+        "llm_awq",
+        "aawq_gemm",
+        "aawq_gemv",
+    ]:
+        model_opt.quant_layers = opt.quant_layers
+        model_opt.quant_type = opt.quant_type

     if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel":
         model_opt.world_size = opt.world_size
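
To see what the new guard in `load_test_model` does, here is a self-contained sketch with hypothetical option objects (the `Namespace` stand-ins and the `w_1`/`w_2` layer names are illustrative): quantization settings stored in an AWQ checkpoint take precedence over whatever `-quant_type`/`-quant_layers` were passed on the inference command line.

```python
from argparse import Namespace

# Hypothetical stand-ins: the checkpoint was quantized with AutoAWQ GEMM,
# while the inference command line asks for bnb_8bit.
model_opt = Namespace(quant_type="aawq_gemm", quant_layers=["w_1", "w_2"])
opt = Namespace(quant_type="bnb_8bit", quant_layers=[])

# Same condition as in the diff above: only non-AWQ checkpoints get overridden.
if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
    "llm_awq",
    "aawq_gemm",
    "aawq_gemv",
]:
    model_opt.quant_layers = opt.quant_layers
    model_opt.quant_type = opt.quant_type

print(model_opt.quant_type)  # -> "aawq_gemm": the checkpoint's AWQ settings win
```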
@@ -304,6 +309,21 @@ def build_base_model(model_opt, vocabs):
             model = replace_bnb_linear(
                 model, module_to_convert=nonlora_to_quant, q_type=model_opt.quant_type
             )
+        elif model_opt.quant_type in ["llm_awq", "aawq_gemm", "aawq_gemv"]:
+            logger.info(
+                "%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
+            )
+            try:
+                from onmt.modules.awq_linear import replace_awq_linear
+            except ImportError:
+                raise ImportError("Install llm-awq/AutoAWQ to use awq quantized model")
+            model = replace_awq_linear(
+                model,
+                module_to_convert=nonlora_to_quant,
+                w_bit=model_opt.w_bit,
+                group_size=model_opt.group_size,
+                q_type=model_opt.quant_type,
+            )
         else:
             logger.info("compression type %s not supported." % model_opt.quant_type)

38 changes: 38 additions & 0 deletions onmt/modules/awq_linear.py
@@ -0,0 +1,38 @@
import torch.nn as nn


def replace_awq_linear(
    model, module_to_convert=[], w_bit=4, group_size=128, q_type="llm_awq"
):
    if q_type == "llm_awq":
        try:
            from awq.quantize.qmodule import WQLinear
        except ImportError:
            raise ImportError("Install llm-awq to use awq")
        AWQLin = WQLinear
    elif q_type in ["aawq_gemm", "aawq_gemv"]:
        try:
            from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
        except ImportError:
            raise ImportError("Install AutoAWQ to use awq")
        if q_type == "aawq_gemm":
            AWQLin = WQLinear_GEMM
        else:
            AWQLin = WQLinear_GEMV
    else:
        raise ValueError("No Awq framework for this value")

    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_awq_linear(module, module_to_convert, w_bit, group_size, q_type)

        if isinstance(module, nn.Linear) and name in module_to_convert:
            model._modules[name] = AWQLin(
                w_bit=w_bit,
                group_size=group_size,
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                dev=module.weight.device,
            )
    return model
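
A minimal usage sketch of the new helper, assuming AutoAWQ (with its CUDA kernels) is installed. The `ToyFFN` module and the `w_1`/`w_2` names are illustrative, not the real OpenNMT-py module tree, and the swapped-in `WQLinear_GEMM` shells still need their quantized weights loaded from an AWQ checkpoint afterwards:

```python
import torch.nn as nn

from onmt.modules.awq_linear import replace_awq_linear


class ToyFFN(nn.Module):
    """A stand-in block; feature sizes are multiples of group_size=128."""

    def __init__(self):
        super().__init__()
        self.w_1 = nn.Linear(256, 512, bias=False)
        self.w_2 = nn.Linear(512, 256, bias=False)


model = nn.Sequential(ToyFFN())

# Swap the named Linear children for AutoAWQ 4-bit GEMM modules.
model = replace_awq_linear(
    model,
    module_to_convert=["w_1", "w_2"],
    w_bit=4,
    group_size=128,
    q_type="aawq_gemm",
)
print(model)  # the w_1 / w_2 entries are now WQLinear_GEMM modules
```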
18 changes: 17 additions & 1 deletion onmt/opts.py
@@ -1565,10 +1565,26 @@ def _add_quant_opts(parser):
         "--quant_type",
         "-quant_type",
         default="bnb_8bit",
-        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4"],
+        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
         type=str,
         help="Type of compression.",
     )
+    group.add(
+        "--w_bit",
+        "-w_bit",
+        type=int,
+        default=4,
+        choices=[4],
+        help="W_bit quantization.",
+    )
+    group.add(
+        "--group_size",
+        "-group_size",
+        default=128,
+        choices=[128],
+        type=int,
+        help="group size quantization.",
+    )


 def train_opts(parser):
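
For intuition on the two new options, here is back-of-the-envelope arithmetic (an approximation; exact buffer layouts differ slightly between llm-awq and AutoAWQ) for a 4096x4096 linear layer at `w_bit=4` and `group_size=128`, i.e. one scale and one zero-point per group of 128 input weights:

```python
in_f, out_f, w_bit, group = 4096, 4096, 4, 128

fp16_bytes = in_f * out_f * 2                        # dense fp16 weight
qweight_bytes = in_f * out_f * w_bit // 8            # packed 4-bit weights
scales_bytes = (in_f // group) * out_f * 2           # one fp16 scale per group
qzeros_bytes = (in_f // group) * out_f * w_bit // 8  # packed 4-bit zero-points

awq_bytes = qweight_bytes + scales_bytes + qzeros_bytes
print(f"{fp16_bytes / awq_bytes:.1f}x smaller")      # roughly 3.8x vs fp16
```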