@@ -320,6 +321,9 @@ python run_quantization.py ./config/llama/ptq_argument.json
# GPTQ quantization launch command reference
python run_quantization.py ./config/llama/gptq_argument.json
+# FOEM quantization launch command reference
+python run_quantization.py ./config/llama/foem_argument.json
+
# W8A8C8 (INT) quantization launch command reference
python run_quantization.py ./config/llama/ptq_c8_argument.json
diff --git a/llm/config/llama/foem_argument.json b/llm/config/llama/foem_argument.json
new file mode 100644
index 000000000000..16b83e43266c
--- /dev/null
+++ b/llm/config/llama/foem_argument.json
@@ -0,0 +1,18 @@
+{
+    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
+    "per_device_train_batch_size": 8,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": false,
+    "fp16_opt_level": "O0",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/foem_ckpts",
+    "do_eval": true,
+    "eval_with_do_generation": false,
+    "do_gptq": true,
+    "do_foem": true,
+    "unified_checkpoint": true,
+    "gptq_step": 8
+}
\ No newline at end of file
diff --git a/llm/config/llama/irqlora_argument.json b/llm/config/llama/irqlora_argument.json
new file mode 100644
index 000000000000..63526eb450af
--- /dev/null
+++ b/llm/config/llama/irqlora_argument.json
@@ -0,0 +1,35 @@
+{
+    "model_name_or_path": "facebook/llama-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/ir_qlora_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-04,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "lora": true,
+    "zero_padding": false,
+    "use_flash_attention": false,
+    "unified_checkpoint": true,
+    "weight_quantize_algo": "nf4",
+    "ir_qlora": true
+}
\ No newline at end of file
diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md
index 5cc07822f1b7..f64fcf4bba1f 100644
--- a/llm/docs/finetune.md
+++ b/llm/docs/finetune.md
@@ -29,6 +29,8 @@
- QLoRA: Quantized Low-Rank Adaptation. Compared with standard LoRA it cuts memory usage by up to an additional 33%, which makes it especially useful when GPU memory is constrained. QLoRA typically takes about 20% longer than plain LoRA, but the substantial memory savings often make it the only viable option when GPU memory is limited.
+- IR-QLoRA: Information-retention quantized LoRA (Accurate LoRA-Finetuning Quantization of LLMs via Information Retention, ICML 2024 Oral). Compared with standard QLoRA, it further improves accuracy through accurate information retention during quantization.
+
## 3. Quick Start
@@ -81,6 +83,9 @@ python run_finetune.py ./config/llama/lora_argument.json
# Single-GPU QLoRA
python run_finetune.py ./config/llama/qlora_argument.json
+# Single-GPU IR-QLoRA
+python run_finetune.py ./config/llama/irqlora_argument.json
+
# Multi-GPU LoRA
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/lora_argument.json
diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md
index 322e4d9b582e..b0028e338f85 100644
--- a/llm/docs/quantization.md
+++ b/llm/docs/quantization.md
@@ -8,6 +8,7 @@
Adds the PieceWiseSearch parameter search algorithm and extends it to **all linear layers**, adjusting the weight and activation distributions of the model to reduce the loss of the subsequent A8W8 PTQ quantization.
- **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that quantizes large-model weights to 4-bit integers losslessly, improving model inference speed.
- **AWQ**. [AWQ](https://arxiv.org/abs/2306.00978) is a mainstream weight quantization algorithm that quantizes large-model weights to 4-bit integers losslessly, improving model inference speed.
+- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a quantization algorithm that takes first-order error compensation into account, further reducing quantization error compared with GPTQ (see Section 2.8 below for usage).

@@ -94,7 +95,13 @@ python run_quantization.py ./config/llama/ptq_c8_argument.json
python run_quantization.py ./config/llama/fp8_ptq_argument.json
```
-### 2.8 Quantization parameters
+### 2.8 FOEM quantization
+
+```shell
+python run_quantization.py ./config/llama/foem_argument.json
+```
+
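+FOEM reuses the GPTQ pipeline, so it is enabled together with `do_gptq`. A minimal sketch of the
+FOEM-related QuantArgument settings (see `./config/llama/foem_argument.json` for a complete
+configuration; `foem_beta` is optional and defaults to 0.1):
+
+```json
+{
+    "do_gptq": true,
+    "do_foem": true,
+    "foem_beta": 0.1,
+    "gptq_step": 8
+}
+```
+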
+### 2.9 Quantization parameters
Quantization parameters (QuantArgument)
@@ -130,6 +137,7 @@ python run_quantization.py ./config/llama/fp8_ptq_argument.json
- `load_quant_model`: Whether to load a quantized model. Defaults to False. Used to validate the quantized model: if set to True, weights are loaded from output_dir. To run this step, `do_ptq` must be set to False. If smooth or shift was used during quantization, the same configuration must be kept when loading (shift_step/search_step can be set to 8). Note that this currently only supports loading in pdparams format; to use it, set `"unified_checkpoint": false`.
- `skip_list_names`: List of layer names to skip during quantization. Defaults to an empty list. Partial name matching is supported, e.g. ['down_proj'] skips all ffn2 layers.
- `do_gptq`: Whether to run GPTQ quantization. GPTQ quantizes the model to WINT4 with higher accuracy than plain PTQ, but takes longer. Defaults to False.
+- `do_foem`: Whether to run FOEM quantization. FOEM quantizes the model to WINT4 with higher accuracy than plain GPTQ, but takes longer. Defaults to False.
- `gptq_step`: Number of GPTQ quantization steps, i.e. the number of model forward passes. Defaults to 8.
- `do_awq`: Whether to run AWQ quantization. AWQ quantizes the model to WINT4 with higher accuracy than plain PTQ. Defaults to False.
- `auto_clip`: Whether AWQ automatically searches for clipping thresholds and clips the model weights. Clipping helps the accuracy of the quantized model, but the search is slow. Defaults to False.
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 31427a516f2d..3e340f3c7012 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -591,6 +591,12 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
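+
+        # IR-QLoRA: rebuild the quantized LoRA base weights with information retention, using a
+        # temporary full-precision copy of the base model for calibration (freed right afterwards).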
+        if model_args.ir_qlora:
+            from paddlenlp.quantization.irqlora_utils import get_my_model
+
+            model2 = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
+            model = get_my_model(model, model2)
+            del model2
+
        model.print_trainable_parameters()
    if model_args.lokr:
diff --git a/llm/utils/quant.py b/llm/utils/quant.py
index c2459994ffb3..da26370a5d3e 100644
--- a/llm/utils/quant.py
+++ b/llm/utils/quant.py
@@ -481,7 +481,7 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
description="GPTQ",
max_eval_iters=quant_args.gptq_step,
)
- cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
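+        # When do_foem is set, fasterquant is expected to apply FOEM's first-order error
+        # compensation (scaled by foem_beta) on top of the standard GPTQ update.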
+        cur_quant_layer.fasterquant(
+            percdamp=0.1, groupsize=-1, actorder=True, do_foem=quant_args.do_foem, foem_beta=quant_args.foem_beta
+        )
        del cur_quant_layer
        setattr(parent_layer, sub_name, cur_layer)
    logger.info("***** GPTQ done *****")
diff --git a/paddlenlp/quantization/irqlora_utils.py b/paddlenlp/quantization/irqlora_utils.py
new file mode 100644
index 000000000000..e507987692c7
--- /dev/null
+++ b/paddlenlp/quantization/irqlora_utils.py
@@ -0,0 +1,166 @@
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddleslim.lc.quantizers.quant_func import create_dynamic_map
+from scipy.stats import norm
+
+from paddlenlp.peft.lora.lora_quantization_layers import QuantizationLoRALinear
+from paddlenlp.quantization.qlora import qlora_weight_dequantize, qlora_weight_quantize
+
+# Largest quantile of the NF4 construction (offset 0.9677083 in QLoRA); sigma normalizes the
+# tau search range in `search` relative to each block's absmax.
+sigma = 1 / norm.ppf(np.linspace(0.9677083, 0.5, 9)[:-1]).tolist()[0]
+
+def get_my_model(model, model_fp, blocksize2=256, tau_range=0.1, tau_n=25):
+    """Replace the QuantizationLoRALinear layers in `model` with IR-QLoRA layers calibrated on `model_fp`."""
+    model.model = _replace_with_ours_lora_4bit_linear(
+        model.model, model_fp=model_fp, blocksize2=blocksize2, tau_range=tau_range, tau_n=tau_n
+    )
+    return model
+
+normal_map_fp8 = create_dynamic_map()
+normal_map_fp8 = paddle.to_tensor(normal_map_fp8)
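+
+# quantize_tensor / dequantize_tensor round-trip a tensor through the normal_map_fp8 codebook; they are
+# used to store the per-block offset tau compactly as a code index plus a per-group absmax.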
+def quantize_tensor(X, L):
+    X_shape = X.shape
+    X_expanded = X.reshape([-1, 1])
+    L_reshaped = L.reshape([1, -1])
+    abs_diff = paddle.abs(X_expanded - L_reshaped)
+    min_index = paddle.argmin(abs_diff, axis=-1)
+    min_index = min_index.cast("uint8").reshape(X_shape)
+    return min_index
+
+def dequantize_tensor(X, L):
+    return paddle.index_select(L, axis=0, index=X.cast("int32").reshape([-1])).reshape(X.shape)
+
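+# nf4_quant subtracts the per-block offset tau from the weight, NF4-quantizes the shifted weight with
+# double quantization, and returns tau itself compressed via quantize_tensor (code index, absmax).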
+@paddle.no_grad()
+def nf4_quant(weight, weight_shape, tau, state, quant_algo):
+    weight = weight.reshape([-1, 256, 64])
+    tau = tau.reshape([-1, 256, 1])
+    _weight = (weight - tau).reshape(weight_shape)
+    nf4_weight = qlora_weight_quantize(_weight.cuda(), quant_algo, double_quant=True)
+    tau2 = tau.abs().max(axis=1, keepdim=True)
+    tau1 = quantize_tensor(tau / tau2, normal_map_fp8)
+    return nf4_weight, tau1.reshape([-1, 256]), tau2.reshape([-1, 1])
+
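+# evaluate_entropy unpacks the two 4-bit codes packed in each uint8 value and measures, per quantization
+# block, the entropy of the code distribution: higher entropy means more NF4 levels are actually used.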
+@paddle.no_grad()
+def evaluate_entropy(weight_int8, blocksize):
+    _weight_int8 = weight_int8.reshape([-1, 1])
+    # Each uint8 stores two 4-bit codes: unpack the high and low nibbles, then group by block.
+    weight_nf4 = paddle.concat((_weight_int8 // 16, _weight_int8 & paddle.to_tensor(15).cast("uint8")), 1)
+    weight_nf4 = weight_nf4.reshape([1, -1, blocksize])
+    weight_nf4_repeat = weight_nf4.cast("int32").tile([16, 1, 1])
+    values = paddle.to_tensor(list(range(16))).reshape([16, 1, 1])
+    freqs = (weight_nf4_repeat == values.cast("int32")).sum(axis=-1, keepdim=True).cast("float32") / blocksize
+    entropy = -freqs * paddle.log2(freqs)
+    entropy = paddle.where(paddle.isnan(entropy), paddle.zeros_like(entropy), entropy)
+    entropy = entropy.sum(axis=0)
+    return entropy
+
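+# search implements the information calibration step of IR-QLoRA: it grid-searches a per-block offset tau
+# around the block median (within +/- tau_range * sigma * absmax) and keeps the offset that maximizes the
+# entropy of the resulting NF4 codes.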
+@paddle.no_grad()
+def search(
+    fp_weight: paddle.Tensor, fp_weight_shape, state, quant_algo, tau_range=0.1, tau_n=25, blocksize=64, blocksize2=256
+):
+    fp_weight = fp_weight.reshape([-1, blocksize2, blocksize])
+    tau0 = fp_weight.median(2, keepdim=True)  # [-1, 256, 1]
+    absmax = (fp_weight - tau0).abs().max(2, keepdim=True)
+
+    entropy_max, factor_best = None, None
+    for factor in np.linspace(-tau_range * sigma, tau_range * sigma, tau_n * 2 + 1).tolist():
+        tau = factor * absmax + tau0
+        nf4_weight, _, _ = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
+        entropy = evaluate_entropy(nf4_weight["quant_weight"], blocksize)
+
+        if entropy_max is None:
+            entropy_max = entropy
+            factor_best = paddle.full_like(entropy, factor)
+        else:
+            factor_best = paddle.where(entropy > entropy_max, paddle.full_like(entropy, factor), factor_best)
+            entropy_max = paddle.maximum(entropy_max, entropy)
+
+    tau = factor_best.reshape([-1, 256, 1]).cast("float32") * absmax + tau0
+    nf4_weight, tau1, tau2 = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
+    return nf4_weight, tau1, tau2
+
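+# IRQuantizationLoRALinear wraps an existing QuantizationLoRALinear: it re-quantizes the base weight with
+# the entropy-calibrated offset tau and adds two learnable scalar gates (the information elastic
+# connection) that mix pooled/broadcast activations into the LoRA-A and LoRA-B paths.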
+class IRQuantizationLoRALinear(QuantizationLoRALinear):
+    def __init__(self, old_model, model_fp=None, blocksize2=256, tau_range=0.1, tau_n=51):
+        # Take over the attributes of the already-built QuantizationLoRALinear instead of re-running __init__.
+        for key, value in old_model.__dict__.items():
+            setattr(self, key, value)
+
+        fp_weight = model_fp.weight.detach().cpu()
+        fp_weight_shape = fp_weight.shape
+        del model_fp
+
+        nf4_weight, tau1, tau2 = search(
+            fp_weight=fp_weight,
+            fp_weight_shape=fp_weight_shape,
+            state=None,
+            quant_algo=self.weight_quantize_algo,
+            tau_range=tau_range,
+            tau_n=tau_n,
+            blocksize2=blocksize2,
+        )
+
+        self.quant_weight.set_value(nf4_weight["quant_weight"])
+        self.qquant_scale = nf4_weight["qquant_scale"]
+        self.double_quant_scale = nf4_weight["double_quant_scale"]
+        self.quant_sacle_offset = nf4_weight["quant_sacle_offset"]
+        self.tau_quant, self.tau_absmax = tau1, tau2
+        self.fp_weight_shape = fp_weight_shape
+
+        del fp_weight, nf4_weight
+
+        # Learnable scalar gates for the information elastic connections, initialized to zero.
+        a = paddle.zeros([1])
+        self.lora_default_A_scale = paddle.create_parameter(
+            shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a)
+        )
+        self.lora_default_B_scale = paddle.create_parameter(
+            shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a)
+        )
+
+    def forward(self, x: paddle.Tensor):
+        with paddle.no_grad():
+            # Dequantize the NF4 weight and add back the per-block offset tau.
+            fp_B = qlora_weight_dequantize(
+                self.quant_weight,
+                self.weight_quantize_algo,
+                (self.qquant_scale, self.double_quant_scale.cast("float32"), self.quant_sacle_offset.cast("float32")),
+                double_quant=True,
+            )
+            tau = dequantize_tensor(self.tau_quant, normal_map_fp8).reshape([-1, 256, 1])
+            tau = tau * self.tau_absmax.reshape([-1, 1, 1])
+            blocksize = paddle.prod(paddle.to_tensor(fp_B.shape)) // paddle.prod(paddle.to_tensor(tau.shape))
+            fp_B = (fp_B.reshape([-1, blocksize.item()]) + tau.reshape([-1, 1])).reshape(self.fp_weight_shape)
+
+        result = F.linear(x, fp_B, self.bias)
+
+        if not self.disable_lora:
+            x1 = self.lora_dropout(x)
+            # Information elastic connection: mix an average-pooled copy of x into the LoRA-A path
+            # and a broadcast copy of x2 into the LoRA-B path, gated by the learnable scalars.
+            pooled_x = x.reshape(x.shape[:-1] + [self.lora_A.shape[-1], -1]).mean(axis=-1)
+            x2 = x1 @ self.lora_A + self.lora_default_A_scale * pooled_x
+            x3 = (x2 @ self.lora_B).reshape(x2.shape + [-1]) + self.lora_default_B_scale * x2.unsqueeze(-1)
+            x3 = x3.reshape(x2.shape[:-1] + [-1])
+            result += x3 * self.scaling
+
+        return result
+
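+# Recursively swap every QuantizationLoRALinear in `model` for an IRQuantizationLoRALinear, pairing each
+# layer with its full-precision counterpart from `model_fp`.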
+def _replace_with_ours_lora_4bit_linear(
+    model: nn.Layer, current_key_name=None, model_fp=None, blocksize2=256, tau_range=0.5, tau_n=51
+):
+    assert isinstance(model_fp, nn.Layer)
+    _modules = dict(model_fp.named_children())
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, QuantizationLoRALinear):
+            new_layer = IRQuantizationLoRALinear(
+                module, model_fp=_modules[name], blocksize2=blocksize2, tau_range=tau_range, tau_n=tau_n
+            )
+            setattr(model, name, new_layer)
+
+        if len(list(module.children())) > 0 and name in _modules:
+            _replace_with_ours_lora_4bit_linear(
+                module, current_key_name, _modules[name], blocksize2, tau_range, tau_n
+            )
+        current_key_name.pop(-1)
+    return model
diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py
index 2e244d211158..3a165c900637 100644
--- a/paddlenlp/trl/model_config.py
+++ b/paddlenlp/trl/model_config.py
@@ -136,6 +136,7 @@ class ModelConfig:
"help": "Block size for quant_scale of weight quant_scale(Only available for nf4 or fp4 quant_scale.)"
},
)
+ ir_qlora: bool = field(default=False, metadata={"help": "Whether to use IR-QLoRA"})
apply_hadamard: bool = field(default=False, metadata={"help": "Whether to apply hadamard"})
hadamard_block_size: int = field(default=32, metadata={"help": "hadamard block size"})
quant_input_grad: bool = field(default=False, metadata={"help": "Whether to quantize input grad"})
diff --git a/paddlenlp/trl/quant_config.py b/paddlenlp/trl/quant_config.py
index a21481f70107..a66e0b405c0f 100644
--- a/paddlenlp/trl/quant_config.py
+++ b/paddlenlp/trl/quant_config.py
@@ -104,6 +104,9 @@ class QuantConfig:
    # GPTQ related parameters
    do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
    gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
+    # FOEM related parameters
+    do_foem: bool = field(default=False, metadata={"help": "Whether to use FOEM"})
+    foem_beta: float = field(default=0.1, metadata={"help": "Beta coefficient for FOEM first-order error compensation"})
    # AWQ related parameters, default for WINT4
    do_awq: bool = field(default=False, metadata={"help": "Whether to use AWQ Search"})