4 changes: 4 additions & 0 deletions docs/en/llm/README.md
@@ -314,6 +314,7 @@ Large model quantization reduces 16-bit and 32-bit floating-point model paramete

- **PTQ**. The PaddleSlim team's self-developed adaptive LLM.PTQ quantization algorithm builds upon [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145), adding the PieceWiseSearch parameter search algorithm to adjust model weight and activation distributions and reduce subsequent A8W8 PTQ quantization loss.
- **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that enables lossless 4-bit integer quantization of large model weights to improve inference speed.
- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a PTQ method that explicitly incorporates first-order gradient terms to improve quantization error compensation.

<div align="center">
<img width="500" alt="llm" src="https://github.com/PaddlePaddle/PaddleNLP/assets/37530985/969b62db-9692-4d50-b91a-85cff305d153">
@@ -339,6 +340,9 @@ python run_quantization.py ./config/llama/ptq_argument.json
# GPTQ Quantization Command Reference
python run_quantization.py ./config/llama/gptq_argument.json

# FOEM Quantization Command Reference
python run_quantization.py ./config/llama/foem_argument.json

# W8A8C8(INT) Quantization Command Reference
python run_quantization.py ./config/llama/ptq_c8_argument.json

4 changes: 4 additions & 0 deletions llm/README.md
@@ -295,6 +295,7 @@ python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./llm/tools/

- **PTQ**. The PaddleSlim team's self-developed adaptive LLM.PTQ quantization algorithm builds upon [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145), adding the PieceWiseSearch parameter search algorithm to adjust model weight and activation distributions and reduce subsequent A8W8 PTQ quantization loss.
- **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to improve inference speed.
- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a quantization algorithm that takes first-order error compensation into account, reducing quantization error further than GPTQ.

<div align="center">
<img width="500" alt="llm" src="https://github.com/PaddlePaddle/PaddleNLP/assets/37530985/969b62db-9692-4d50-b91a-85cff305d153">
@@ -320,6 +321,9 @@ python run_quantization.py ./config/llama/ptq_argument.json
# GPTQ quantization launch command reference
python run_quantization.py ./config/llama/gptq_argument.json

# FOEM quantization launch command reference
python run_quantization.py ./config/llama/foem_argument.json

# W8A8C8(INT) quantization launch command reference
python run_quantization.py ./config/llama/ptq_c8_argument.json

18 changes: 18 additions & 0 deletions llm/config/llama/foem_argument.json
@@ -0,0 +1,18 @@
{
"model_name_or_path": "meta-llama/Meta-Llama-3-8B",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"src_length": 1024,
"max_length": 2048,
"bf16": false,
"fp16_opt_level": "O0",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/gptq_ckpts",
"do_eval": true,
"eval_with_do_generation": false,
"do_gptq": true,
"do_foem": true,
"unified_checkpoint": true,
"gptq_step": 8
}
35 changes: 35 additions & 0 deletions llm/config/llama/irqlora_argument.json
@@ -0,0 +1,35 @@
{
"model_name_or_path": "facebook/llama-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/ir_qlora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"zero_padding": false,
"use_flash_attention": false,
"unified_checkpoint": true,
"weight_quantize_algo": "nf4",
"ir_qlora": true
}
5 changes: 5 additions & 0 deletions llm/docs/finetune.md
@@ -29,6 +29,8 @@

- QLoRA: Quantized Low-Rank Adaptation. Compared with standard LoRA, it cuts memory usage by up to an additional 33%, which makes it especially useful when GPU memory is constrained. QLoRA typically takes about 20% longer than plain LoRA, but its substantial memory savings make it the only viable option when GPU memory is limited.

- IR-QLoRA: Information-Retention QLoRA (Accurate LoRA-Finetuning Quantization of LLMs via Information Retention, ICML 2024 Oral). Compared with standard QLoRA, it further improves accuracy through accurate information retention; a simplified sketch of the idea follows below.
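
To make "information retention" concrete, here is a minimal NumPy sketch of the per-block calibration implemented in `paddlenlp/quantization/irqlora_utils.py` by this change: for each weight block, search over a shifted zero point and keep the shift whose quantization codes carry the most information (maximum entropy). The function names below are illustrative only; the actual implementation works on NF4 double-quantized Paddle tensors.

```python
import numpy as np


def code_entropy(codes: np.ndarray, n_levels: int = 16) -> float:
    """Shannon entropy (in bits) of the code distribution of one block."""
    freqs = np.bincount(codes, minlength=n_levels) / codes.size
    nz = freqs[freqs > 0]
    return float(-(nz * np.log2(nz)).sum())


def quantize_block(block: np.ndarray, levels: np.ndarray) -> np.ndarray:
    """Absmax-normalize the block and map each value to the nearest code index."""
    scaled = block / (np.abs(block).max() + 1e-12)
    return np.abs(scaled[:, None] - levels[None, :]).argmin(axis=1)


def search_shift(block: np.ndarray, levels: np.ndarray, tau_range: float = 0.1, n_candidates: int = 51):
    """Pick the shift tau whose quantized codes retain the most information."""
    tau0 = np.median(block)
    absmax = np.abs(block - tau0).max()
    best_tau, best_entropy = tau0, -1.0
    for factor in np.linspace(-tau_range, tau_range, n_candidates):
        tau = tau0 + factor * absmax
        h = code_entropy(quantize_block(block - tau, levels))
        if h > best_entropy:
            best_tau, best_entropy = tau, h
    return best_tau, best_entropy


# Toy example: one 64-value block, 16 uniform levels standing in for the NF4 code book.
rng = np.random.default_rng(0)
tau, h = search_shift(rng.normal(size=64), np.linspace(-1.0, 1.0, 16))
print(f"best shift {tau:.4f}, code entropy {h:.3f} bits")
```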


## 3. Quick Start

@@ -81,6 +83,9 @@ python run_finetune.py ./config/llama/lora_argument.json
# Single-GPU QLoRA
python run_finetune.py ./config/llama/qlora_argument.json

# Single-GPU IR-QLoRA
python run_finetune.py ./config/llama/irqlora_argument.json

# Multi-GPU LoRA
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/lora_argument.json

10 changes: 9 additions & 1 deletion llm/docs/quantization.md
@@ -8,6 +8,7 @@
adds the PieceWiseSearch parameter search algorithm and extends it to **all linear layers**, adjusting model weight and activation distributions to reduce subsequent A8W8 PTQ quantization loss.
- **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to improve inference speed.
- **AWQ**. [AWQ](https://arxiv.org/abs/2306.00978) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to improve inference speed.
- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a quantization algorithm that takes first-order error compensation into account, reducing quantization error further than GPTQ; a schematic view follows this list.
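
As a schematic view (a paraphrase of the bullet above, not the paper's notation), layer-wise error compensation can be read off a second-order Taylor expansion of the task loss around the full-precision weight $W$, with $\Delta W = \hat{W} - W$ the quantization perturbation:

$$
\Delta \mathcal{L} \;\approx\; g^{\top} \Delta W \;+\; \tfrac{1}{2}\, \Delta W^{\top} H\, \Delta W .
$$

GPTQ-style methods drop the first-order term (assuming a converged model, $g \approx 0$) and compensate the error using only the Hessian term; FOEM keeps the gradient term $g$ explicitly, which is where its additional error reduction over GPTQ comes from.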

<div align="center">
<img width="800" alt="llm" src="https://github.com/PaddlePaddle/PaddleNLP/assets/63761690/fe8f941b-4b35-48ca-814f-96533d7e24ce">
@@ -94,7 +95,13 @@ python run_quantization.py ./config/llama/ptq_c8_argument.json
python run_quantization.py ./config/llama/fp8_ptq_argument.json
```

### 2.8 Quantization Parameters
### 2.8 FOEM Quantization

```shell
python run_quantization.py ./config/llama/foem_argument.json
```

### 2.9 Quantization Parameters

<summary>&emsp; Quantization Parameters (QuantArgument)</summary>

@@ -130,6 +137,7 @@ python run_quantization.py ./config/llama/fp8_ptq_argument.json
- `load_quant_model`: Whether to load an already quantized model; defaults to False. Used to verify the accuracy of a quantized model: if set to True, weights are loaded from output_dir. To start this process, `do_ptq` must be set to False. If smooth or shift was used during quantization, the same configuration must be kept when loading (shift_step/search_step can be set to 8). Note that this currently only supports loading weights in pdparams format; to use it, set `"unified_checkpoint": false`.
- `skip_list_names`: List of layer names to skip during quantization; defaults to an empty list. Partial layer-name strings can be used for matching, e.g. ['down_proj'] skips all ffn2 layers.
- `do_gptq`: Whether to apply GPTQ quantization. GPTQ performs WINT4 quantization of the model; it is more accurate than plain PTQ but takes longer to quantize. Defaults to False.
- `do_foem`: Whether to apply FOEM quantization. FOEM performs WINT4 quantization of the model; it is more accurate than plain GPTQ but takes longer to quantize (see the sketch below for how this flag is consumed). Defaults to False.
- `gptq_step`: Number of GPTQ quantization steps, i.e. the number of model forward passes; defaults to 8.
- `do_awq`: Whether to apply AWQ quantization. AWQ performs WINT4 quantization of the model and is more accurate than plain PTQ. Defaults to False.
- `auto_clip`: Whether AWQ automatically searches for clipping thresholds and clips the model weights. Clipping helps the accuracy of the quantized model, but the search is slow. Defaults to False.
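
For orientation, the sketch below shows the new FOEM switches as a stub dataclass mirroring the fields added to `QuantConfig` (paddlenlp/trl/quant_config.py) and, in the trailing comment, how they reach the per-layer quantizer in the `apply_gptq` loop of `llm/utils/quant.py`. The stub is illustrative only; note that FOEM runs inside the GPTQ pass, which is why `llm/config/llama/foem_argument.json` keeps `do_gptq: true`.

```python
from dataclasses import dataclass


@dataclass
class FoemSwitches:
    # Illustrative stub; the real fields live in QuantConfig (paddlenlp/trl/quant_config.py).
    do_gptq: bool = False    # FOEM is applied inside the GPTQ pass
    gptq_step: int = 8       # number of forward passes used for calibration
    do_foem: bool = False    # enable first-order error compensation
    foem_beta: float = 0.1   # FOEM beta hyper-parameter (default 0.1)


quant_args = FoemSwitches(do_gptq=True, do_foem=True)

# In apply_gptq (llm/utils/quant.py), the per-layer call then becomes:
#   cur_quant_layer.fasterquant(
#       percdamp=0.1, groupsize=-1, actorder=True,
#       do_foem=quant_args.do_foem, foem_beta=quant_args.foem_beta,
#   )
```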
6 changes: 6 additions & 0 deletions llm/run_finetune.py
@@ -591,6 +591,12 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

        if model_args.ir_qlora:
            # Re-quantize the NF4 LoRA base weights with the information-retention
            # calibration, using a full-precision copy of the model as reference.
            from paddlenlp.quantization.irqlora_utils import get_my_model

            model2 = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
            model = get_my_model(model, model2)
            del model2

        model.print_trainable_parameters()

    if model_args.lokr:
2 changes: 1 addition & 1 deletion llm/utils/quant.py
@@ -481,7 +481,7 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
            description="GPTQ",
            max_eval_iters=quant_args.gptq_step,
        )
        cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
        cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True, do_foem=quant_args.do_foem, foem_beta=quant_args.foem_beta)
        del cur_quant_layer
        setattr(parent_layer, sub_name, cur_layer)
    logger.info("***** GPTQ done *****")
166 changes: 166 additions & 0 deletions paddlenlp/quantization/irqlora_utils.py
@@ -0,0 +1,166 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import operator
import numpy as np
from paddlenlp.peft.lora.lora_quantization_layers import QuantizationLoRALinear, QuantizationLoRABaseLinear
from functools import reduce # Required in Python 3
from scipy.stats import norm
from paddleslim.lc.quantizers.quant_func import create_dynamic_map
from paddleslim.lc.layers.linear import Linear4bit
from paddlenlp.quantization.qlora import qlora_weight_dequantize, qlora_weight_quantize

cache_folder_path = ''
module_num = 0
sigma = 1 / norm.ppf(paddle.linspace(0.9677083, 0.5, 9)[:-1]).tolist()[0]

def get_my_model(model, model_fp, blocksize2=256, tau_range=0.1, tau_n=25):
    # Swap every QuantizationLoRALinear in the quantized LoRA model for an
    # IRQuantizationLoRALinear, calibrated against the full-precision model `model_fp`.
    model.model = _replace_with_ours_lora_4bit_linear(model.model, model_fp=model_fp, blocksize2=blocksize2, tau_range=tau_range, tau_n=tau_n)
    return model


def prod(iterable):
    return reduce(operator.mul, iterable, 1)

normal_map_fp8 = create_dynamic_map()
normal_map_fp8 = paddle.to_tensor(normal_map_fp8)
def quantize_tensor(X, L, idx=False):
    # Map each element of X to the index of the nearest code in the 1-D code book L.
    X_shape = X.shape
    X_expanded = X.reshape([-1, 1])
    L_reshaped = L.reshape([1, -1])
    abs_diff = paddle.abs(X_expanded - L_reshaped)
    min_index = paddle.argmin(abs_diff, axis=-1)
    min_index = min_index.cast("uint8").reshape(X_shape)
    return min_index


def dequantize_tensor(X, L):
    # Look the stored indices back up in the code book L.
    return paddle.index_select(L, axis=0, index=paddle.to_tensor(X, dtype=paddle.int32).reshape([-1])).reshape(X.shape)

@paddle.no_grad()
def nf4_quant(weight, weight_shape, tau, state, quant_algo):
    # Subtract the per-block shift tau, NF4 double-quantize the shifted weight,
    # and store tau itself compactly (an 8-bit code plus its absmax scale).
    weight = weight.reshape([-1, 256, 64])
    tau = tau.reshape([-1, 256, 1])
    _weight = (weight - tau).reshape(weight_shape)
    nf4_weight = qlora_weight_quantize(_weight.cuda(), quant_algo, double_quant=True)
    tau2 = tau.abs().max(axis=1, keepdim=True)[0]
    tau1 = quantize_tensor(tau / tau2, normal_map_fp8)
    return nf4_weight, tau1.reshape([-1, 256]), tau2.reshape([-1, 1])

@paddle.no_grad()
def evaluate_entropy(weight_int8, blocksize):
    # Unpack the two 4-bit codes stored per byte and compute the Shannon entropy
    # of the code distribution within each block of `blocksize` values.
    _weight_int8 = weight_int8.reshape([-1, 1])
    weight_nf4 = paddle.concat((_weight_int8 // 16, _weight_int8 & paddle.to_tensor(15).cast('uint8')), 1).reshape([1, -1, blocksize])
    weight_nf4_repeat = weight_nf4.cast('int32').tile([16, 1, 1])
    values = paddle.to_tensor(list(range(16))).reshape([16, 1, 1])
    freqs = (weight_nf4_repeat == values.cast('int32')).sum(axis=-1, keepdim=True) / blocksize
    entropy = -freqs * paddle.log2(freqs)
    entropy = paddle.where(paddle.isnan(entropy), 0., entropy)
    entropy = entropy.sum(axis=0)
    return entropy

@paddle.no_grad()
def search(fp_weight: paddle.Tensor, fp_weight_shape, state, quant_algo, tau_range=0.1, tau_n=25, blocksize=64, blocksize2=256):
    # For every block, search over candidate shifts tau around the block median and
    # keep the one whose 4-bit codes have maximal entropy (information retention).
    fp_weight = fp_weight.reshape([-1, blocksize2, blocksize])
    tau0 = fp_weight.median(2, keepdim=True)[0]  # [-1, 256, 1]
    absmax = (fp_weight - tau0).abs().max(2, keepdim=True)[0]

    entropy_max, factor_best = None, None
    for factor in np.linspace(-tau_range * sigma, tau_range * sigma, tau_n * 2 + 1):
        tau = factor * absmax + tau0
        nf4_weight, _, _ = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
        entropy = evaluate_entropy(nf4_weight['quant_weight'], blocksize)

        if entropy_max is None:
            entropy_max = entropy
            factor_best = paddle.full_like(entropy, factor)
        else:
            factor_best = paddle.where(entropy > entropy_max, factor, factor_best)
            entropy_max = paddle.maximum(entropy_max, entropy)

    tau = factor_best.reshape([-1, 256, 1]).cast('float32') * absmax + tau0
    nf4_weight, tau1, tau2 = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
    return nf4_weight, tau1, tau2

class IRQuantizationLoRALinear(QuantizationLoRALinear):
    """QuantizationLoRALinear whose NF4 base weight is re-quantized with an
    information-retaining per-block shift, plus learnable default LoRA scales."""

    def __init__(
        self, old_model, model_fp=None, blocksize2=256, tau_range=0.1, tau_n=51
    ):
        # Take over all attributes of the existing quantized LoRA layer in place.
        for key, value in old_model.__dict__.items():
            setattr(self, key, value)

        fp_weight = model_fp.weight.data.contiguous().cpu()
        fp_weight_shape = fp_weight.shape

        quant_weight, quant_dtype, quantization_config, weight_quantize_algo, dtype, quant_scale, quant_state = self.quant_weight, self.quant_dtype, self.quantization_config, self.weight_quantize_algo, self._dtype, self.quant_scale, None

        del model_fp

        # Re-quantize the full-precision weight with the entropy-maximizing shift.
        nf4_weight, tau1, tau2 = search(
            fp_weight=fp_weight,
            fp_weight_shape=fp_weight_shape,
            state=quant_state,
            quant_algo=weight_quantize_algo,
            tau_range=tau_range, tau_n=tau_n,
            blocksize2=blocksize2
        )

        self.quant_weight.data = nf4_weight['quant_weight']
        self.qquant_scale = nf4_weight['qquant_scale']
        self.double_quant_scale = nf4_weight['double_quant_scale']
        self.quant_sacle_offset = nf4_weight['quant_sacle_offset']
        self.tau_quant, self.tau_absmax = tau1, tau2
        self.fp_weight_shape = fp_weight_shape

        del fp_weight, nf4_weight

        # Learnable scalar scales for the extra (elastic) LoRA connections.
        a = paddle.zeros(1)
        self.lora_default_A_scale = paddle.create_parameter(shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a))
        self.lora_default_B_scale = paddle.create_parameter(shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a))

    def forward(self, x: paddle.Tensor):
        with paddle.no_grad():
            # Dequantize the NF4 base weight and add the stored per-block shift back.
            fp_B = qlora_weight_dequantize(self.quant_weight, self.weight_quantize_algo, (self.qquant_scale, self.double_quant_scale.cast('float32'), self.quant_sacle_offset.cast('float32')), double_quant=True)
            tau = (dequantize_tensor(self.tau_quant, normal_map_fp8).reshape([-1, 256, 1]) * self.tau_absmax.reshape([-1, 1, 1]))
            blocksize = paddle.prod(paddle.to_tensor(fp_B.shape)) // paddle.prod(paddle.to_tensor(tau.shape))
            fp_B = (fp_B.reshape([-1, blocksize.item()]) + tau.reshape([-1, 1])).reshape(self.fp_weight_shape)

        result = paddle.nn.functional.linear(x, fp_B, self.bias)

        if not self.disable_lora:
            x1 = self.lora_dropout(x)
            x2 = x1 @ self.lora_A + self.lora_default_A_scale * x.reshape([_ for _ in x.shape[:-1]] + [self.lora_A.shape[-1]] + [-1]).mean(axis=-1)
            x3 = ((x2 @ self.lora_B).reshape([_ for _ in x2.shape] + [-1]) + (self.lora_default_B_scale * x2.unsqueeze(-1))).reshape([_ for _ in x2.shape[:-1]] + [-1])
            result += x3 * self.scaling

        return result

def _replace_with_ours_lora_4bit_linear(
    model: nn.Layer, current_key_name=None, model_fp=None, blocksize2=256, tau_range=0.5, tau_n=51
):
    # Recursively walk the quantized model together with the full-precision model
    # and replace every QuantizationLoRALinear with an IRQuantizationLoRALinear.
    assert isinstance(model_fp, nn.Layer)
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, QuantizationLoRALinear):
            _modules = dict(model_fp.named_sublayers())

            print(name)
            new_layer = IRQuantizationLoRALinear(dict(model.named_sublayers())[name], model_fp=dict(model_fp.named_sublayers())[name], blocksize2=blocksize2, tau_range=tau_range, tau_n=tau_n)
            setattr(model, name, new_layer)

        if len(list(module.children())) > 0:
            _modules = dict(model_fp.named_sublayers())
            if name in _modules.keys():
                _ = _replace_with_ours_lora_4bit_linear(
                    module,
                    current_key_name, _modules[name], blocksize2, tau_range, tau_n
                )
            else:
                _ = _replace_with_ours_lora_4bit_linear(
                    module,
                    current_key_name, None, blocksize2, tau_range, tau_n
                )
        current_key_name.pop(-1)
    return model
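
A minimal usage sketch, mirroring how `llm/run_finetune.py` wires this module in; the caller and the construction of the quantized LoRA model are assumed, not shown:

```python
from paddlenlp.transformers import AutoModelForCausalLM

from paddlenlp.quantization.irqlora_utils import get_my_model


def apply_ir_qlora(model, model_name_or_path="facebook/llama-7b"):
    # `model` is assumed to be a LoRAModel whose base Linear layers were quantized
    # with weight_quantize_algo="nf4" (see llm/config/llama/irqlora_argument.json).
    model_fp = AutoModelForCausalLM.from_pretrained(model_name_or_path)  # full-precision reference
    model = get_my_model(model, model_fp)  # QuantizationLoRALinear -> IRQuantizationLoRALinear
    del model_fp  # the reference copy is only needed while re-quantizing
    return model
```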
1 change: 1 addition & 0 deletions paddlenlp/trl/model_config.py
@@ -136,6 +136,7 @@ class ModelConfig:
"help": "Block size for quant_scale of weight quant_scale(Only available for nf4 or fp4 quant_scale.)"
},
)
ir_qlora: bool = field(default=False, metadata={"help": "Whether to use IR-QLoRA"})
apply_hadamard: bool = field(default=False, metadata={"help": "Whether to apply hadamard"})
hadamard_block_size: int = field(default=32, metadata={"help": "hadamard block size"})
quant_input_grad: bool = field(default=False, metadata={"help": "Whether to quantize input grad"})
3 changes: 3 additions & 0 deletions paddlenlp/trl/quant_config.py
@@ -104,6 +104,9 @@ class QuantConfig:
    # GPTQ related parameters
    do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
    gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
    # FOEM related parameters
    do_foem: bool = field(default=False, metadata={"help": "Whether to use FOEM"})
    foem_beta: float = field(default=0.1, metadata={"help": "beta for FOEM"})

    # AWQ related parameters, default for WINT4
    do_awq: bool = field(default=False, metadata={"help": "Whether to use AWQ Search"})